From 4d8f690eaead9480e91fbfb5ab1bd745d7b1f70f Mon Sep 17 00:00:00 2001 From: "Romain F. Laine" Date: Fri, 2 Jul 2021 07:42:09 +0100 Subject: [PATCH] Upload for v1.13 release --- .../3D-RCAN_ZeroCostDL4Mic.ipynb | 2 +- .../Cellpose_2D_ZeroCostDL4Mic.ipynb | 2183 +---------- .../DRMIME_2D_ZeroCostDL4Mic.ipynb | 2 +- .../DecoNoising_2D_ZeroCostDL4Mic.ipynb | 2062 +--------- ...roCostDL4Mic_BioImageModelZoo_export.ipynb | 3427 +---------------- .../DenoiSeg_ZeroCostDL4Mic.ipynb | 2167 +---------- .../Detectron2_2D_ZeroCostDL4Mic.ipynb | 2 +- .../MaskRCNN_ZeroCostDL4Mic.ipynb | 1 + ...roCostDL4Mic_BioImageModelZoo_export.ipynb | 2423 +----------- ...roCostDL4Mic_BioImageModelZoo_export.ipynb | 2690 +------------ ...Mic_Interactive_annotations_Cellpose.ipynb | 2 +- .../fnet_2D_ZeroCostDL4Mic.ipynb | 1 + Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb | 2040 +--------- Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb | 2083 +--------- Colab_notebooks/ChangeLog.txt | 22 + Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb | 2149 +---------- .../Deep-STORM_2D_ZeroCostDL4Mic.ipynb | 3111 +-------------- .../Latest_ZeroCostDL4Mic_Release.csv | 2 +- .../Noise2Void_2D_ZeroCostDL4Mic.ipynb | 2 +- .../Noise2Void_3D_ZeroCostDL4Mic.ipynb | 2 +- .../StarDist_2D_ZeroCostDL4Mic.ipynb | 2270 +---------- .../StarDist_3D_ZeroCostDL4Mic.ipynb | 1878 +-------- Colab_notebooks/Template_ZeroCostDL4Mic.ipynb | 1740 +-------- Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb | 2404 +----------- Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb | 2 +- Colab_notebooks/fnet_2D_ZeroCostDL4Mic.ipynb | 1 + Colab_notebooks/fnet_3D_ZeroCostDL4Mic.ipynb | 1 + Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb | 1 - Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb | 2832 +------------- 30 files changed, 50 insertions(+), 35454 deletions(-) create mode 100644 Colab_notebooks/Beta notebooks/MaskRCNN_ZeroCostDL4Mic.ipynb create mode 100644 Colab_notebooks/Beta notebooks/fnet_2D_ZeroCostDL4Mic.ipynb create mode 100644 Colab_notebooks/fnet_2D_ZeroCostDL4Mic.ipynb create mode 100644 Colab_notebooks/fnet_3D_ZeroCostDL4Mic.ipynb delete mode 100644 Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb diff --git a/Colab_notebooks/Beta notebooks/3D-RCAN_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/3D-RCAN_ZeroCostDL4Mic.ipynb index c5f972a1..d63de4fe 100644 --- a/Colab_notebooks/Beta notebooks/3D-RCAN_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Beta notebooks/3D-RCAN_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"3D-RCAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1JHO_gnWRtiFhhD5YgLE2UwgTB-63MVii","timestamp":1610723892159},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **3D-RCAN**\n","\n","\n","\n","---\n","\n","3D-RCAN is a 
neural network capable of image restoration from corrupted bio-images, first released in 2020 by [Chen *et al.* in bioRxiv](https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1). \n","\n"," **This particular notebook enables restoration of 3D datasets. If you are interested in restoring 2D datasets, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki), jointly developed by the Jacquemet (https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes**, by Chen *et al.*, published in bioRxiv in 2020 (https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1)\n","\n","And on the source code found at: /~https://github.com/AiviaCommunity/3D-RCAN\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities, but the training and test data of the restoration experiments are also available from the authors of the original paper [here](https://www.dropbox.com/sh/hieldept1x476dw/AAC0pY3FrwdZBctvFF0Fx0L3a?dl=0).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cells: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code, and the code can be modified by selecting the cell. To execute the cell, move your cursor over the `[ ]`-mark on the left side of the cell (a play button appears). Click to execute the cell. After execution is done, the animation of the play button stops. You can create a new code cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain, from top to bottom:\n","\n","*Table of contents* = contains the structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contains examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contains all available files. After mounting your Google Drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. 
Do not upload anything in here!\n","\n","---\n","\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this, click File -> Save a copy in Drive.\n","\n","To **edit a cell**, double-click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. Before getting started**\n","---\n"," For 3D-RCAN to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with an indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split into two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available on our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. \n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install 3D-RCAN and dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"1kvDz2Ft4FX6"},"source":["## **2.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","#@markdown ##Install 3D-RCAN and dependencies\n","\n","!git clone /~https://github.com/AiviaCommunity/3D-RCAN\n","\n","import os\n","\n","\n","!pip install q keras==2.2.5\n","\n","!pip install colorama; sys_platform=='win32'\n","!pip install jsonschema\n","!pip install numexpr\n","!pip install tqdm>=4.41.0\n","\n","\n","\n","%tensorflow_version 1.x\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"IWhWPjyi33M2"},"source":["## **2.2. 
Restart your runtime**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"SBag2atY36Js"},"source":["** Here you need to restart your runtime to load the newly installed dependencies**\n","\n"," Click on \"Runtime\" ---> \"Restart Runtime\""]},{"cell_type":"markdown","metadata":{"id":"_nRLOjuk3_8z"},"source":["## **2.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"TYYBwn_54G9j","cellView":"form"},"source":["Notebook_version = ['1.11.1']\n","\n","#@markdown ##Load key dependencies\n","\n","!pip install q keras==2.2.5\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to 3D-RCAN -------\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = '3D-RCAN'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:16]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate and Time: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs (image dimensions: '+str(shape)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","\n","
<table width=40%>
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>number_of_steps</td><td>{1}</td></tr>
<tr><td>percentage_validation</td><td>{2}</td></tr>
<tr><td>num_residual_groups</td><td>{3}</td></tr>
<tr><td>num_residual_blocks</td><td>{4}</td></tr>
<tr><td>num_channels</td><td>{5}</td></tr>
<tr><td>channel_reduction</td><td>{6}</td></tr>
</table>
\n"," \"\"\".format(number_of_epochs,number_of_steps, percentage_validation, num_residual_groups, num_residual_blocks, num_channels, channel_reduction)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(32, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_3D_RCAN.png').shape\n"," pdf.image('/content/TrainingDataExample_3D_RCAN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- 3D-RCAN: Chen et al. \"Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes.\" bioRxiv 2020 https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," if trained:\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n"," else:\n"," pdf.output('/content/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = '3D RCAN'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," 
pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='You can see these curves in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}{7}
{0}{1}{2}{3}{4}{5}{6}{7}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes, by Chen et al. bioRxiv (2020)'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","!pip freeze > requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number of epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: 256**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`num_residual_groups`:** Number of residual groups in RCAN. **Default value: 5** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the num_residual_groups value until the OOM error disappear.**\n","\n","**`num_residual_blocks`:** Number of residual channel attention blocks in each residual group in RCAN. **Default value: 3** \n","\n","**`num_channels`:** Number of feature channels in RCAN. 
**Default value: 32** \n","\n","**`channel_reduction`:** Channel reduction ratio for channel attention. **Default value: 8** \n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","number_of_steps = 256#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","percentage_validation = 10 #@param {type:\"number\"}\n","num_residual_groups = 5 #@param {type:\"number\"}\n","num_residual_blocks = 3 #@param {type:\"number\"}\n","num_channels = 32 #@param {type:\"number\"}\n","channel_reduction = 8 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\") \n"," percentage_validation = 10\n"," num_residual_groups = 5\n"," num_channels = 32\n"," num_residual_blocks = 3\n"," channel_reduction = 8\n"," \n","\n","percentage = percentage_validation/100\n","\n","\n","full_model_path = model_path+'/'+model_name\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# Here we split the data between training and validation\n","# Here we count the number of files in the training target folder\n","Filelist = os.listdir(Training_target)\n","number_files = len(Filelist)\n","\n","File_for_validation = int((number_files)/percentage_validation)+1\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","Training_target_temp = \"/content/training_target\"\n","if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n","os.makedirs(Training_target_temp)\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","Validation_target_temp = \"/content/validation_target\"\n","if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n","os.makedirs(Validation_target_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","list_target = os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n"," \n"," \n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n","\n","list_source_temp = os.listdir(os.path.join(Training_source_temp))\n","list_target_temp = os.listdir(os.path.join(Training_target_temp))\n","\n","\n","#Here we move images to be used for validation\n","for i in range(File_for_validation):\n","\n"," name = list_source_temp[i]\n"," shutil.move(Training_source_temp+\"/\"+name, Validation_source_temp+\"/\"+name)\n"," shutil.move(Training_target_temp+\"/\"+name, Validation_target_temp+\"/\"+name)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_3D_RCAN.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","cellView":"form"},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = False #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = False #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," 
io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source_temp,Training_target_temp,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source_temp,Training_target_temp)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","print(\"Preparing the config file...\")\n","\n","if Use_Data_augmentation == True:\n"," Training_source_temp = Saving_path+'/augmented_source'\n"," Training_target_temp = Saving_path+'/augmented_target'\n","\n","# Here we prepare the JSON file\n","\n","import json \n"," \n","# Config file for 3D-RCAN \n","dictionary ={\n"," \"epochs\": number_of_epochs,\n"," \"steps_per_epoch\": number_of_steps,\n"," \"num_residual_groups\": num_residual_groups,\n"," \"training_data_dir\": {\"raw\": Training_source_temp,\n"," \"gt\": Training_target_temp},\n"," \n"," \"validation_data_dir\": {\"raw\": Validation_source_temp,\n"," \"gt\": Validation_target_temp},\n"," \"num_channels\": num_channels,\n"," \"num_residual_blocks\": num_residual_blocks,\n"," \"channel_reduction\": channel_reduction\n"," \n"," \n","}\n"," \n","json_object = json.dumps(dictionary, indent = 4) \n"," \n","with open(\"/content/config.json\", \"w\") as outfile: \n"," outfile.write(json_object)\n","\n","# Export pdf summary of training parameters\n","pdf_export(augmentation = Use_Data_augmentation)\n","\n","print(\"Done\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n","\n"]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","!python /content/3D-RCAN/train.py -c /content/config.json -o \"$full_model_path\"\n","\n","print(\"Training, done.\")\n","\n","\n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","%load_ext tensorboard\n","%tensorboard --logdir \"$full_QC_model_path\"\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","path_QC_prediction = path_metrics_save+'Prediction'\n","\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_QC_prediction):\n"," shutil.rmtree(path_QC_prediction)\n","os.makedirs(path_QC_prediction)\n","\n","\n","# Perform the predictions\n","\n","print(\"Restoring images...\")\n","\n","!python /content/3D-RCAN/apply.py -m \"$full_QC_model_path\" -i \"$Source_QC_folder\" -o \"$path_QC_prediction\"\n","\n","print(\"Done...\")\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import 
numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack_raw = io.imread(os.path.join(path_metrics_save+\"Prediction/\",thisFile))\n"," test_prediction_stack = test_prediction_stack_raw[:, 1, :, :]\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # 
-------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = 
img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction_raw = io.imread(os.path.join(path_metrics_save+'Prediction/', Test_FileList[-1]))\n","\n","img_Prediction = img_Prediction_raw[:, 1, :, :]\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. 
To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n","\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","cellView":"form"},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","print(\"Restoring images...\")\n","\n","!python /content/3D-RCAN/apply.py -m \"$full_Prediction_model_path\" -i \"$Data_folder\" -o \"$Result_folder\"\n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Data_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y_raw = imread(Result_folder+\"/\"+file)\n"," y = y_raw[:, 1, :, :]\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using 3D-RCAN!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"3D-RCAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1JHO_gnWRtiFhhD5YgLE2UwgTB-63MVii","timestamp":1610723892159},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **3D-RCAN**\n","\n","\n","\n","---\n","\n","3D-RCAN is a neural network capable of image restoration from corrupted bio-images, first released in 2020 by [Chen *et al.* in biorXiv](https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1). \n","\n"," **This particular notebook enables restoration of 3D dataset. If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes**, by Chen *et al.* published in bioRxiv in 2020 (https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1)\n","\n","And source code found in: /~https://github.com/AiviaCommunity/3D-RCAN\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://www.dropbox.com/sh/hieldept1x476dw/AAC0pY3FrwdZBctvFF0Fx0L3a?dl=0).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. 
Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. \n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Install 3D-RCAN and dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"1kvDz2Ft4FX6"},"source":["## **1.1. 
Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = ['1.13']\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory\n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","##############################\n","# Import statements go here: #\n","# import ... #\n","# import ... #\n","##############################\n","\n","#@markdown ##Install 3D-RCAN and dependencies\n","\n","!git clone /~https://github.com/AiviaCommunity/3D-RCAN\n","\n","import os\n","import pandas as pd\n","\n","!pip uninstall -y keras-nightly\n","\n","\n","!pip install q keras==2.2.5\n","\n","!pip install colorama; sys_platform=='win32'\n","!pip install jsonschema\n","!pip install numexpr\n","!pip install tqdm>=4.41.0\n","\n","\n","\n","%tensorflow_version 1.x\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","%load_ext memory_profiler\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)\n","\n","#Force session restart\n","exit(0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SBag2atY36Js"},"source":["\n","## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
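After the runtime has restarted, you can optionally confirm that the versions pinned in section 1.1 are active. This is a minimal sketch, not part of the original workflow; it only assumes the install cell above completed without errors.

```python
# Optional sanity check after the forced restart (assumes the install cell in section 1.1 ran).
%tensorflow_version 1.x
import tensorflow as tf
import keras

print(tf.__version__)     # expected to be a 1.x release in this notebook
print(keras.__version__)  # expected 2.2.5, as pinned by the install cell above
```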
\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"_nRLOjuk3_8z"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"TYYBwn_54G9j","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = '3D RCAN'\n","\n","\n","#@markdown ##Load key dependencies\n","\n","!pip install q keras==2.2.5\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","#PDF export\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n"," \n"," day = datetime.now()\n"," datetime_str = str(day)[0:16]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate and Time: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs (image dimensions: '+str(shape)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","\n","
<tr><th width = 50% align=\"left\">Parameter</th><th width = 50% align=\"left\">Value</th></tr>\n","
<tr><td width = 50%>number_of_epochs</td><td width = 50%>{0}</td></tr>\n","
<tr><td width = 50%>number_of_steps</td><td width = 50%>{1}</td></tr>\n","
<tr><td width = 50%>percentage_validation</td><td width = 50%>{2}</td></tr>\n","
<tr><td width = 50%>num_residual_groups</td><td width = 50%>{3}</td></tr>\n","
<tr><td width = 50%>num_residual_blocks</td><td width = 50%>{4}</td></tr>\n","
<tr><td width = 50%>num_channels</td><td width = 50%>{5}</td></tr>\n","
<tr><td width = 50%>channel_reduction</td><td width = 50%>{6}</td></tr>
\n"," \"\"\".format(number_of_epochs,number_of_steps, percentage_validation, num_residual_groups, num_residual_blocks, num_channels, channel_reduction)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(32, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_3D_RCAN.png').shape\n"," pdf.image('/content/TrainingDataExample_3D_RCAN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- 3D-RCAN: Chen et al. \"Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes.\" bioRxiv 2020 https://www.biorxiv.org/content/10.1101/2020.08.27.270439v1'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," if trained:\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n"," else:\n"," pdf.output('/content/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = '3D RCAN'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," 
pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='You can see these curves in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th align=\"left\">{0}</th><th align=\"left\">{1}</th><th align=\"left\">{2}</th><th align=\"left\">{3}</th><th align=\"left\">{4}</th><th align=\"left\">{5}</th><th align=\"left\">{6}</th><th align=\"left\">{7}</th></tr>\n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," 
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td><td>{7}</td></tr>\n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Three-dimensional residual channel attention networks denoise and sharpen fluorescence microscopy image volumes, by Chen et al. bioRxiv (2020)'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","!pip freeze > requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **2. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sfz7CtBQE1I_"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","Connect to a hosted runtime.
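If the "drive" folder still does not show up, a minimal check from a code cell can confirm the mount (a sketch; the mount point is the one used by `drive.mount` in the cell above):

```python
# Minimal check that Google Drive is mounted where the cell above placed it.
import os

mount_point = '/content/gdrive'
print(os.path.isdir(mount_point))  # should print True once mounting has completed
print(os.listdir(mount_point))     # typically lists 'MyDrive' (or 'My Drive' on older Colab runtimes)
```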
"]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number of epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: 256**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`num_residual_groups`:** Number of residual groups in RCAN. **Default value: 5** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the num_residual_groups value until the OOM error disappear.**\n","\n","**`num_residual_blocks`:** Number of residual channel attention blocks in each residual group in RCAN. **Default value: 3** \n","\n","**`num_channels`:** Number of feature channels in RCAN. **Default value: 32** \n","\n","**`channel_reduction`:** Channel reduction ratio for channel attention. 
**Default value: 8** \n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","number_of_steps = 256#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","percentage_validation = 10 #@param {type:\"number\"}\n","num_residual_groups = 5 #@param {type:\"number\"}\n","num_residual_blocks = 3 #@param {type:\"number\"}\n","num_channels = 32 #@param {type:\"number\"}\n","channel_reduction = 8 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\") \n"," percentage_validation = 10\n"," num_residual_groups = 5\n"," num_channels = 32\n"," num_residual_blocks = 3\n"," channel_reduction = 8\n"," \n","\n","percentage = percentage_validation/100\n","\n","\n","full_model_path = model_path+'/'+model_name\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# Here we split the data between training and validation\n","# Here we count the number of files in the training target folder\n","Filelist = os.listdir(Training_target)\n","number_files = len(Filelist)\n","\n","File_for_validation = int((number_files)/percentage_validation)+1\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","Training_target_temp = \"/content/training_target\"\n","if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n","os.makedirs(Training_target_temp)\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","Validation_target_temp = \"/content/validation_target\"\n","if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n","os.makedirs(Validation_target_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","list_target = os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n"," \n"," \n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n","\n","list_source_temp = os.listdir(os.path.join(Training_source_temp))\n","list_target_temp = os.listdir(os.path.join(Training_target_temp))\n","\n","\n","#Here we move images to be used for validation\n","for i in range(File_for_validation):\n","\n"," name = list_source_temp[i]\n"," shutil.move(Training_source_temp+\"/\"+name, Validation_source_temp+\"/\"+name)\n"," shutil.move(Training_target_temp+\"/\"+name, Validation_target_temp+\"/\"+name)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_3D_RCAN.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
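As a rough illustration of what the rotation and flip augmentation described below does to a single stack, here is a minimal NumPy sketch (a random array stands in for a real (z, y, x) volume; the calls mirror those used by `rotation_aug` and `flip` in the cell below):

```python
# Illustrative only: mimics the operations applied per stack by the augmentation cell below.
import numpy as np

stack = np.random.rand(16, 128, 128)   # placeholder for one (z, y, x) training volume

rot90  = np.rot90(stack, axes=(1, 2))  # 90-degree rotation in the XY plane, applied slice by slice
rot180 = np.rot90(rot90, axes=(1, 2))
rot270 = np.rot90(rot180, axes=(1, 2))
flipped = np.fliplr(stack)             # flip, as done by the flip option below

# Rotation alone gives 4 versions of each pair; adding the flip of each rotation gives 8 in total.
print(rot90.shape, rot180.shape, rot270.shape, flipped.shape)
```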
Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","cellView":"form"},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = False #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = False #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," 
io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source_temp,Training_target_temp,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source_temp,Training_target_temp)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","print(\"Preparing the config file...\")\n","\n","if Use_Data_augmentation == True:\n"," Training_source_temp = Saving_path+'/augmented_source'\n"," Training_target_temp = Saving_path+'/augmented_target'\n","\n","# Here we prepare the JSON file\n","\n","import json \n"," \n","# Config file for 3D-RCAN \n","dictionary ={\n"," \"epochs\": number_of_epochs,\n"," \"steps_per_epoch\": number_of_steps,\n"," \"num_residual_groups\": num_residual_groups,\n"," \"training_data_dir\": {\"raw\": Training_source_temp,\n"," \"gt\": Training_target_temp},\n"," \n"," \"validation_data_dir\": {\"raw\": Validation_source_temp,\n"," \"gt\": Validation_target_temp},\n"," \"num_channels\": num_channels,\n"," \"num_residual_blocks\": num_residual_blocks,\n"," \"channel_reduction\": channel_reduction\n"," \n"," \n","}\n"," \n","json_object = json.dumps(dictionary, indent = 4) \n"," \n","with open(\"/content/config.json\", \"w\") as outfile: \n"," outfile.write(json_object)\n","\n","# Export pdf summary of training parameters\n","pdf_export(augmentation = Use_Data_augmentation)\n","\n","print(\"Done\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n","\n"]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","!python /content/3D-RCAN/train.py -c /content/config.json -o \"$full_model_path\"\n","\n","print(\"Training, done.\")\n","\n","\n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
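As a point of reference, the cell in section 4.1 above writes a `config.json` that `train.py` reads. With the default values from section 3.1 and the temporary folders created there, its content is roughly the following (a sketch built with the same keys, not an exact dump of your file):

```python
# Illustrative only: mirrors the dictionary written to /content/config.json in section 4.1.
import json

config = {
    "epochs": 30,                 # number_of_epochs (default)
    "steps_per_epoch": 256,       # number_of_steps (default)
    "num_residual_groups": 5,
    "num_residual_blocks": 3,
    "num_channels": 32,
    "channel_reduction": 8,
    "training_data_dir":   {"raw": "/content/training_source",   "gt": "/content/training_target"},
    "validation_data_dir": {"raw": "/content/validation_source", "gt": "/content/validation_target"},
}

print(json.dumps(config, indent=4))
```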
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","%load_ext tensorboard\n","%tensorboard --logdir \"$full_QC_model_path\"\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","path_QC_prediction = path_metrics_save+'Prediction'\n","\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_QC_prediction):\n"," shutil.rmtree(path_QC_prediction)\n","os.makedirs(path_QC_prediction)\n","\n","\n","# Perform the predictions\n","\n","print(\"Restoring images...\")\n","\n","!python /content/3D-RCAN/apply.py -m \"$full_QC_model_path\" -i \"$Source_QC_folder\" -o \"$path_QC_prediction\"\n","\n","print(\"Done...\")\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import 
numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack_raw = io.imread(os.path.join(path_metrics_save+\"Prediction/\",thisFile))\n"," test_prediction_stack = test_prediction_stack_raw[:, 1, :, :]\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # 
-------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = 
img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction_raw = io.imread(os.path.join(path_metrics_save+'Prediction/', Test_FileList[-1]))\n","\n","img_Prediction = img_Prediction_raw[:, 1, :, :]\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. 
To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n","\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","cellView":"form"},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","print(\"Restoring images...\")\n","\n","!python /content/3D-RCAN/apply.py -m \"$full_Prediction_model_path\" -i \"$Data_folder\" -o \"$Result_folder\"\n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Data_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y_raw = imread(Result_folder+\"/\"+file)\n"," y = y_raw[:, 1, :, :]\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"3yoo_0c7FU-4"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","\n","* This version also now includes built-in version check and the version log that you're reading now.\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using 3D-RCAN!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/Cellpose_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/Cellpose_2D_ZeroCostDL4Mic.ipynb index e3dcc661..bf533042 100644 --- a/Colab_notebooks/Beta notebooks/Cellpose_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Beta notebooks/Cellpose_2D_ZeroCostDL4Mic.ipynb @@ -1,2182 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "Cellpose_2D_ZeroCostDL4Mic.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "0YVrnjRozAGg" - }, - "source": [ - "This is notebook is in beta, expect bugs and missing features compared to other ZeroCostDL4Mic notebooks\n", - "\n", - "- Training now uses TORCH. 
\n", - "- Currently missing features include: \n", - " - The PDF report is not yet generated\n", - " - The training and validation curves are not saved or visualised\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "V9zNGvape2-I" - }, - "source": [ - "# **Cellpose (2D)**\n", - "\n", - "---\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pwLsIXtEw3Kx" - }, - "source": [ - "**Cellpose 2D** is a deep-learning method that can be used to segment cell and/or nuclei from bioimages and was first published by [Stringer *et al.* in 2020, in Nature Method](https://www.nature.com/articles/s41592-020-01018-x). \n", - "\n", - " **This particular notebook enables nuclei segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist or U-Net 3D notebooks instead.**\n", - "\n", - "---\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is largely based on the paper:\n", - "\n", - "**Cellpose: a generalist algorithm for cellular segmentation** from Stringer *et al.*, Nature Methods, 2020. (https://www.nature.com/articles/s41592-020-01018-x)\n", - "\n", - "**The Original code** is freely available in GitHub:\n", - "/~https://github.com/MouseLand/cellpose\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**\n", - "\n", - "**This notebook was also inspired by the one created by @pr4deepr** which is available here:\n", - "https://colab.research.google.com/github/MouseLand/cellpose/blob/master/notebooks/Cellpose_2D_v0_1.ipynb\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "C5oYf0Q5yXrl" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - " For Cellpose to train, **it needs to have access to a paired training dataset made of images and their corresponding masks (label images)**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n", - "\n", - "**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n", - "\n", - "The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n", - "\n", - "Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n", - "\n", - " **Use 8/16 bit png or Tiff images**.\n", - "\n", - "\n", - "You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. 
This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n", - "\n", - "Here's a common data structure that can work:\n", - "* Experiment A\n", - " - **Training dataset**\n", - " - Images (Training_source)\n", - " - img_1.tif, img_2.tif, ...\n", - " - Label images (Training_target)\n", - " - img_1.tif, img_2.tif, ...\n", - " - **Quality control dataset**\n", - " - Images\n", - " - img_1.tif, img_2.tif\n", - " - Masks \n", - " - img_1.tif, img_2.tif\n", - " - **Data to be predicted**\n", - " - **Results**\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a pretrained model you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4-r1gE7Iamv" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "BDhmUgqCStlm" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "\n", - "%tensorflow_version 1.x\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-oqBTeLaImnU" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "01Djr8v-5pPk" - }, - "source": [ - "\n", - "#@markdown ##Run this cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "#mounts user's Google Drive to Google Colab.\n", - "\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **2. Install Cellpose and dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "3u2mXn3XsWzd", - "cellView": "form" - }, - "source": [ - "Notebook_version = ['1.12.4']\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory\n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "#@markdown ##Install Cellpose and dependencies\n", - "\n", - "#Libraries contains information of certain topics. 
\n", - "#For example the tifffile library contains information on how to handle tif-files.\n", - "\n", - "#Here, we install libraries which are not already included in Colab.\n", - "\n", - "!pip install tifffile # contains tools to operate tiff-files\n", - "!pip install cellpose \n", - "#!pip install mxnet-cu101 \n", - "!pip install wget\n", - "!pip install memory_profiler\n", - "!pip install fpdf\n", - "%load_ext memory_profiler\n", - "\n", - "# ------- Variable specific to Cellpose -------\n", - "\n", - "from urllib.parse import urlparse\n", - "%matplotlib inline\n", - "from cellpose import models\n", - "use_GPU = models.use_gpu()\n", - "\n", - "#import mxnet as mx\n", - "\n", - "from skimage.util import img_as_ubyte\n", - "import cv2\n", - "from cellpose import plot\n", - "from ipywidgets import interact, interact_manual\n", - "from zipfile import ZIP_DEFLATED\n", - "\n", - "\n", - "# For sliders and dropdown menu and progress bar\n", - "from ipywidgets import interact\n", - "import ipywidgets as widgets\n", - "\n", - "\n", - "# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - "import urllib\n", - "import os, random\n", - "import shutil \n", - "import zipfile\n", - "from tifffile import imread, imsave\n", - "import time\n", - "import sys\n", - "import wget\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import csv\n", - "from glob import glob\n", - "from scipy import signal\n", - "from scipy import ndimage\n", - "from skimage import io\n", - "from sklearn.linear_model import LinearRegression\n", - "from skimage.util import img_as_uint\n", - "import matplotlib as mpl\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from astropy.visualization import simple_norm\n", - "from skimage import img_as_float32\n", - "from skimage.util import img_as_ubyte\n", - "from tqdm import tqdm \n", - "from fpdf import FPDF, HTMLMixin\n", - "from datetime import datetime\n", - "import subprocess\n", - "from pip._internal.operations.freeze import freeze\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "\n", - "#Disable some of the tensorflow warnings\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "print(\"Libraries installed\")\n", - "\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "print('Notebook version: '+Notebook_version[0])\n", - "\n", - "strlist = Notebook_version[0].split('.')\n", - "Notebook_version_main = strlist[0]+'.'+strlist[1]\n", - "\n", - "if Notebook_version_main == Latest_notebook_version.columns:\n", - " print(\"This notebook is up-to-date.\")\n", - "else:\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "\n", - "!pip freeze > requirements.txt\n", - "\n", - "#Create a pdf document with training summary\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " # save FPDF() class into a \n", - " # variable pdf \n", - " #from datetime import datetime\n", - "\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'Cellpose 2D'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and methods:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras','csbdeep']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n", - " if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n", - " aug_text = aug_text+'\\n- rotation'\n", - " if flip_left_right != 0 or flip_top_bottom != 0:\n", - " aug_text = aug_text+'\\n- flipping'\n", - " if random_zoom_magnification != 0:\n", - " aug_text = aug_text+'\\n- random zoom magnification'\n", - " if random_distortion != 0:\n", - " aug_text = aug_text+'\\n- random distortion'\n", - " if image_shear != 0:\n", - " aug_text = aug_text+'\\n- image shearing'\n", - " if skew_image != 0:\n", - " aug_text = aug_text+'\\n- image skewing'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>number_of_patches</td><td>{2}</td></tr>
<tr><td>batch_size</td><td>{3}</td></tr>
<tr><td>number_of_steps</td><td>{4}</td></tr>
<tr><td>percentage_validation</td><td>{5}</td></tr>
<tr><td>initial_learning_rate</td><td>{6}</td></tr>
\n", - " \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_CARE2D.png').shape\n", - " pdf.image('/content/TrainingDataExample_CARE2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " if augmentation:\n", - " ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " pdf.multi_cell(190, 5, txt = ref_3, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", - "\n", - "\n", - "#Make a pdf summary of the QC results\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'CARE 2D'\n", - " #model_name = os.path.basename(full_QC_model_path)\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n", - " if os.path.exists(full_QC_model_path+'Quality Control/lossCurvePlots.png'):\n", - " pdf.image(full_QC_model_path+'Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.', align='L')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n", - " pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+'Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " mSSIM_PvsGT = header[1]\n", - " mSSIM_SvsGT = header[2]\n", - " NRMSE_PvsGT = header[3]\n", - " NRMSE_SvsGT = header[4]\n", - " PSNR_PvsGT = header[5]\n", - " PSNR_SvsGT = header[6]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n", - " html = html+header\n", - " for row in metrics:\n", - " image = 
row[0]\n", - " mSSIM_PvsGT = row[1]\n", - " mSSIM_SvsGT = row[2]\n", - " NRMSE_PvsGT = row[3]\n", - " NRMSE_SvsGT = row[4]\n", - " PSNR_PvsGT = row[5]\n", - " PSNR_SvsGT = row[6]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
\"\"\"\n", - "\n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2UfUWjI_askO" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZyMxrSWvavVL" - }, - "source": [ - "## **3.1. Setting main training parameters**\n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5MlTyQVXXvDx" - }, - "source": [ - " **Paths for training, predictions and results**\n", - "\n", - "\n", - "**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of cells) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "\n", - "**Training parameters**\n", - "\n", - "**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 100 epochs, but a full training should run for up to 500-1000 epochs. Evaluate the performance after training (see 5.). **Default value: 500**\n", - "\n", - "**Advanced Parameters - experienced users only**\n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 8**\n", - "\n", - "\n", - "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. 
**Default value: 10** \n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "2HkNZ16BdfJv", - "cellView": "form" - }, - "source": [ - "#@markdown ###Path to training images:\n", - "\n", - "Training_source = \"\" #@param {type:\"string\"}\n", - "Training_target = \"\" #@param {type:\"string\"}\n", - "\n", - "#Define where the patch file will be saved\n", - "base = \"/content\"\n", - "\n", - "# model name and path\n", - "#@markdown ###Name of the model and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# other parameters for training.\n", - "#@markdown ###Training Parameters:\n", - "#@markdown Number of epochs:\n", - "number_of_epochs = 500 #@param {type:\"number\"}\n", - "\n", - "Channel_to_use_for_training = \"Grayscale\" #@param [\"Grayscale\", \"Blue\", \"Green\", \"Red\"]\n", - "\n", - "# @markdown ###If you have a secondary channel that can be used for training, for instance nuclei, choose it here:\n", - "\n", - "Second_training_channel= \"None\" #@param [\"None\", \"Blue\", \"Green\", \"Red\"]\n", - "\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "\n", - "Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "batch_size = 8#@param {type:\"number\"}\n", - "initial_learning_rate = 0.0002 #@param {type:\"number\"}\n", - "percentage_validation = 10#@param {type:\"number\"}\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 8 \n", - " initial_learning_rate = 0.0002\n", - " percentage_validation = 10\n", - "\n", - "#here we check that no model with the same name already exist, if so delete\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n", - " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n", - " \n", - "# Here we enable the cyto pre-trained model by default (in case the cell is not ran)\n", - "model_to_load = \"cyto\"\n", - "# Here we disable data augmentation by default (in case the cell is not ran)\n", - "\n", - "Use_Data_augmentation = True\n", - "\n", - "# This will display a randomly chosen dataset input and output\n", - "random_choice = random.choice(os.listdir(Training_source))\n", - "x = io.imread(Training_source+\"/\"+random_choice)\n", - "norm = simple_norm(x, percent = 99)\n", - "y = io.imread(Training_target+\"/\"+random_choice)\n", - "\n", - "# Find the number of channel in the input image\n", - "\n", - "n_channel = 1 if x.ndim == 2 else x.shape[-1]\n", - "\n", - "\n", - "# Here we match the channel to number\n", - "\n", - "if Channel_to_use_for_training == \"Grayscale\":\n", - " Training_channel = 0\n", - "\n", - " if not n_channel == 1:\n", - " print(bcolors.WARNING +\"!! 
WARNING: your image has more than one channel, choose which channel you want to use for trainning !!\")\n", - "\n", - "if Channel_to_use_for_training == \"Blue\":\n", - " Training_channel = 3\n", - "\n", - "if Channel_to_use_for_training == \"Green\":\n", - " Training_channel = 2\n", - "\n", - "if Channel_to_use_for_training == \"Red\":\n", - " Training_channel = 1\n", - "\n", - "\n", - "if Second_training_channel == \"Blue\":\n", - " Second_training_channel = 3\n", - "\n", - "if Second_training_channel == \"Green\":\n", - " Second_training_channel = 2\n", - "\n", - "if Second_training_channel == \"Red\":\n", - " Second_training_channel = 1\n", - "\n", - "if Second_training_channel == \"None\":\n", - " Second_training_channel = 0\n", - "\n", - "\n", - "if n_channel ==1:\n", - "\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(x, norm=norm, cmap='magma', interpolation='nearest')\n", - " plt.title('Training source')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(y,cmap='nipy_spectral', interpolation='nearest')\n", - " plt.title('Training target')\n", - " plt.axis('off');\n", - "\n", - " plt.savefig('/content/TrainingDataExample_Cellpose2D.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "else:\n", - "\n", - " f=plt.figure(figsize=(20,10))\n", - " plt.subplot(1,3,1)\n", - " plt.imshow(x, interpolation='nearest')\n", - " plt.title('Training source')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,3,2)\n", - " plt.imshow(x[:, :, int(Training_channel-1)],cmap='magma', interpolation='nearest')\n", - " plt.title('Channel used for training')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,3,3)\n", - " plt.imshow(y,cmap='nipy_spectral', interpolation='nearest')\n", - " plt.title('Training target')\n", - " plt.axis('off');\n", - "\n", - " plt.savefig('/content/TrainingDataExample_Cellpose2D.png',bbox_inches='tight',pad_inches=0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qEg6ar0PhuDY" - }, - "source": [ - "## **3.2. Data augmentation**\n", - "---\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "t6q9aqDUhxlw" - }, - "source": [ - "Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n", - "\n", - "By default, a x4 data augmentation is enabled in this notebook." - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "SblwpgmahfBl" - }, - "source": [ - "#Data augmentation\n", - "\n", - "Use_Data_augmentation = True #@param {type:\"boolean\"}\n", - "\n", - "if Use_Data_augmentation:\n", - " print(bcolors.WARNING+\"Data augmentation enabled\") \n", - "\n", - "\n", - "if not Use_Data_augmentation:\n", - " print(bcolors.WARNING+\"Data augmentation disabled\") " - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "s2NC_-Tuc02W" - }, - "source": [ - "\n", - "## **3.3. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. 
**This pre-trained model needs to be a Cellpose model**. \n", - "\n", - " You can also use the pretrained models already available in Cellpose: \n", - "\n", - "- The cytoplasm model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is an optional nuclear channel.\n", - "\n", - "- The cytoplasm2 model is an updated cytoplasm model trained with user-submitted images.\n", - "\n", - "- The nuclear model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is always set to an array of zeros.\n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "sLdgQM6Rc7vp" - }, - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "Use_pretrained_model = True #@param {type:\"boolean\"}\n", - "\n", - "Pretrained_model = \"Cytoplasm2\" #@param [\"Cytoplasm\",\"Cytoplasm2\", \"Nuclei\", \"Own_model\"]\n", - "\n", - "#@markdown ###If using your own model, please provide the path to the model (not the folder):\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "\n", - "if Use_pretrained_model == True :\n", - "\n", - " if Pretrained_model == \"Own_model\": \n", - "\n", - " model_to_load = pretrained_model_path\n", - "\n", - " print('The model '+ str(model_to_load) + \"will be used as a starting point\")\n", - "\n", - " if Pretrained_model == \"Cytoplasm\":\n", - " model_to_load = \"cyto\"\n", - " print('The model Cytoplasm will be used as a starting point')\n", - "\n", - " if Pretrained_model == \"Cytoplasm2\":\n", - " model_to_load = \"cyto2\"\n", - " print('The model Cytoplasm2 (cyto2) will be used as a starting point')\n", - "\n", - " if Pretrained_model == \"Nuclei\":\n", - " model_to_load = \"nuclei\"\n", - " print('The model nuclei will be used as a starting point')\n", - "\n", - "else:\n", - " model_to_load = None\n", - " print(bcolors.WARNING+'No pretrained network will be used.')\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qeYZ7PeValfs" - }, - "source": [ - "#**4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "tsn8WV3Wl0sG", - "cellView": "form" - }, - "source": [ - "#@markdown ##Create the model and dataset objects\n", - "\n", - "\n", - "# Here we check that the model destination folder is empty\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "\n", - "#To use cellpose to work we need to organise the data in a way the network can understand\n", - "\n", - "# Here we count the number of files in the training target folder\n", - "Filelist = os.listdir(Training_target)\n", - "number_files = len(Filelist)\n", - "\n", - "# Here we count the number of file to use for validation\n", - "Image_for_validation = int((number_files)*(percentage_validation/100))\n", - "\n", - "Saving_path= \"/content/\"+model_name\n", - "\n", - "if os.path.exists(Saving_path):\n", - " shutil.rmtree(Saving_path)\n", - "os.makedirs(Saving_path)\n", - "\n", - "train_folder = Saving_path+\"/train_folder\"\n", - "os.makedirs(train_folder)\n", - "\n", - "test_folder = Saving_path+\"/test_folder\"\n", - "os.makedirs(test_folder)\n", - "\n", - "index = 0\n", - "\n", - "print('Copying training source data...')\n", - "for f in tqdm(os.listdir(Training_source)): \n", - " short_name = os.path.splitext(f)\n", - "\n", - " if index < Image_for_validation:\n", - " shutil.copyfile(Training_source+\"/\"+f, test_folder+\"/\"+short_name[0]+\"_img.tif\")\n", - " shutil.copyfile(Training_target+\"/\"+f, test_folder+\"/\"+short_name[0]+\"_masks.tif\") \n", - " else:\n", - " shutil.copyfile(Training_source+\"/\"+f, train_folder+\"/\"+short_name[0]+\"_img.tif\")\n", - " shutil.copyfile(Training_target+\"/\"+f, train_folder+\"/\"+short_name[0]+\"_masks.tif\") \n", - " index = index +1\n", - " \n", - "print(\"Done\")\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jRorFe296LgI" - }, - "source": [ - "## **4.2. Start Training**\n", - "---\n", - "\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Training is currently done using Torch.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "YXUnd3awi6K3", - "cellView": "form" - }, - "source": [ - "#@markdown ##Start training\n", - "\n", - "start = time.time()\n", - "\n", - "if not Use_Data_augmentation:\n", - " #!python -m cellpose --train --use_gpu --mxnet --fast_mode --dir \"$train_folder\" --test_dir \"$test_folder\" --pretrained_model $model_to_load --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter img --mask_filter masks\n", - " !python -m cellpose --train --use_gpu --fast_mode --dir \"$train_folder\" --test_dir \"$test_folder\" --pretrained_model $model_to_load --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter img --mask_filter masks\n", - "\n", - "else:\n", - " #!python -m cellpose --train --use_gpu --mxnet --dir \"$train_folder\" --test_dir \"$test_folder\" --pretrained_model $model_to_load --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter img --mask_filter masks\n", - " !python -m cellpose --train --use_gpu --fast_mode --dir \"$train_folder\" --test_dir \"$test_folder\" --pretrained_model $model_to_load --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter img --mask_filter masks\n", - "\n", - "\n", - "#Settings\n", - "# --check_mkl', action='store_true', help='check if mkl working'\n", - "\n", - "#'--mkldnn', action='store_true', help='for mxnet, force MXNET_SUBGRAPH_BACKEND = \"MKLDNN\"')\n", - "\n", - "#'--train', action='store_true', help='train network using images in dir')\n", - "#'--dir', required=False, help='folder containing data to run or train on')\n", - "# '--mxnet', action='store_true', help='use mxnet')\n", - "# '--img_filter', required=False, default=[], type=str, help='end string for images to run on')\n", - "# '--use_gpu', action='store_true', help='use gpu if mxnet with cuda installed')\n", - "# '--fast_mode', action='store_true', help=\"make code run faster by turning off 4 network averaging\")\n", - "# '--resample', action='store_true', help=\"run dynamics on full image (slower for images with large diameters)\")\n", - "# '--no_interp', action='store_true', help='do not interpolate when running dynamics (was default)')\n", - "# '--do_3D', action='store_true', help='process images as 3D stacks of images (nplanes x nchan x Ly x Lx')\n", - " \n", - "# settings for training\n", - "# parser.add_argument('--train_size', action='store_true', help='train size network at end of training')\n", - "# parser.add_argument('--mask_filter', required=False, default='_masks', type=str, help='end string for masks to run on')\n", - "# parser.add_argument('--test_dir', required=False, default=[], type=str, help='folder containing test data (optional)')\n", - "# parser.add_argument('--learning_rate', required=False, default=0.2, type=float, help='learning rate')\n", - "# parser.add_argument('--n_epochs', required=False, default=500, type=int, help='number of epochs')\n", - "# parser.add_argument('--batch_size', required=False, default=8, type=int, help='batch size')\n", - "# 
parser.add_argument('--residual_on', required=False, default=1, type=int, help='use residual connections')\n", - "# parser.add_argument('--style_on', required=False, default=1, type=int, help='use style vector')\n", - "# parser.add_argument('--concatenation', required=False, dfault=0, type=int, help='concatenate downsampled layers with upsampled layers (off by default which means they are added)')\n", - "\n", - "\n", - "#Here we copy the model to the result folder after training\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "destination = shutil.copytree(Saving_path+\"/train_folder/models\", model_path+\"/\"+model_name)\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "print(\"Your model is also available here: \"+str(model_path+\"/\"+model_name))\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qvbm9EJGaXr9" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n", - "\n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IeiU6D2jGDh4", - "cellView": "form" - }, - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, indicate which model you want to assess:\n", - "\n", - "QC_model = \"Own_model\" #@param [\"Cytoplasm\",\"Cytoplasm2\", \"Nuclei\", \"Own_model\"]\n", - "\n", - "#@markdown ###If using your own model, please provide the path to the model (not the folder):\n", - "\n", - "QC_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ###If using the \"Cytoplasm\" or \"Nuclei\" models, please indicate where you want to save the results:\n", - "Saving_path = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "if Use_the_current_trained_model :\n", - "\n", - " list_files = os.listdir(model_path+\"/\"+model_name)\n", - " \n", - " QC_model_path = model_path+\"/\"+model_name+\"/\"+list_files[0]\n", - " QC_model = \"Own_model\"\n", - "\n", - " #model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=False, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n", - " model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n", - "\n", - " QC_model_folder = os.path.dirname(QC_model_path)\n", - " QC_model_name = os.path.basename(QC_model_folder)\n", - " Saving_path = QC_model_folder\n", - "\n", - " print(\"The \"+str(QC_model_name)+\" model will be evaluated\")\n", - "\n", - "if not Use_the_current_trained_model:\n", - "\n", - " if QC_model == \"Cytoplasm\":\n", - " model = models.Cellpose(gpu=True, model_type=\"cyto\")\n", - " QC_model_folder = Saving_path\n", - " QC_model_name = \"Cytoplasm\"\n", - "\n", - " 
print('The model \"Cytoplasm\" will be evaluated')\n", - "\n", - " if QC_model == \"Cytoplasm2\":\n", - " model = models.Cellpose(gpu=True, model_type=\"cyto2\")\n", - " QC_model_folder = Saving_path\n", - " QC_model_name = \"Cytoplasm2\"\n", - "\n", - " print('The model \"Cytoplasm\" will be evaluated')\n", - "\n", - " if QC_model == \"Nuclei\":\n", - " model = models.Cellpose(gpu=True, model_type=\"nuclei\")\n", - "\n", - " QC_model_folder = Saving_path\n", - " QC_model_name = \"Nuclei\"\n", - "\n", - " print('The model \"Nuclei\" will be evaluated')\n", - " \n", - " if QC_model == \"Own_model\":\n", - "\n", - " if os.path.exists(QC_model_path):\n", - " model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n", - " \n", - " QC_model_folder = os.path.dirname(QC_model_path)\n", - " Saving_path = QC_model_folder\n", - " QC_model_name = os.path.basename(QC_model_folder)\n", - " print(\"The \"+str(QC_model_name)+\" model will be evaluated\")\n", - " \n", - " else: \n", - " print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "#Here we make the folder to save the resuslts if it does not exists\n", - "\n", - "if not Saving_path == \"\":\n", - " if os.path.exists(QC_model_folder) == False:\n", - " os.makedirs(QC_model_folder)\n", - "else:\n", - " print(bcolors.WARNING+'!! WARNING: Indicate where you want to save the results')\n", - "\n", - "\n", - "# Here we load the def that perform the QC, code taken from StarDist /~https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py\n", - "\n", - "import numpy as np\n", - "from numba import jit\n", - "from tqdm import tqdm\n", - "from scipy.optimize import linear_sum_assignment\n", - "from collections import namedtuple\n", - "\n", - "\n", - "matching_criteria = dict()\n", - "\n", - "def label_are_sequential(y):\n", - " \"\"\" returns true if y has only sequential labels from 1... 
\"\"\"\n", - " labels = np.unique(y)\n", - " return (set(labels)-{0}) == set(range(1,1+labels.max()))\n", - "\n", - "\n", - "def is_array_of_integers(y):\n", - " return isinstance(y,np.ndarray) and np.issubdtype(y.dtype, np.integer)\n", - "\n", - "\n", - "def _check_label_array(y, name=None, check_sequential=False):\n", - " err = ValueError(\"{label} must be an array of {integers}.\".format(\n", - " label = 'labels' if name is None else name,\n", - " integers = ('sequential ' if check_sequential else '') + 'non-negative integers',\n", - " ))\n", - " is_array_of_integers(y) or print(\"An error occured\")\n", - " if check_sequential:\n", - " label_are_sequential(y) or print(\"An error occured\")\n", - " else:\n", - " y.min() >= 0 or print(\"An error occured\")\n", - " return True\n", - "\n", - "\n", - "def label_overlap(x, y, check=True):\n", - " if check:\n", - " _check_label_array(x,'x',True)\n", - " _check_label_array(y,'y',True)\n", - " x.shape == y.shape or _raise(ValueError(\"x and y must have the same shape\"))\n", - " return _label_overlap(x, y)\n", - "\n", - "@jit(nopython=True)\n", - "def _label_overlap(x, y):\n", - " x = x.ravel()\n", - " y = y.ravel()\n", - " overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)\n", - " for i in range(len(x)):\n", - " overlap[x[i],y[i]] += 1\n", - " return overlap\n", - "\n", - "\n", - "def intersection_over_union(overlap):\n", - " _check_label_array(overlap,'overlap')\n", - " if np.sum(overlap) == 0:\n", - " return overlap\n", - " n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)\n", - " n_pixels_true = np.sum(overlap, axis=1, keepdims=True)\n", - " return overlap / (n_pixels_pred + n_pixels_true - overlap)\n", - "\n", - "matching_criteria['iou'] = intersection_over_union\n", - "\n", - "\n", - "def intersection_over_true(overlap):\n", - " _check_label_array(overlap,'overlap')\n", - " if np.sum(overlap) == 0:\n", - " return overlap\n", - " n_pixels_true = np.sum(overlap, axis=1, keepdims=True)\n", - " return overlap / n_pixels_true\n", - "\n", - "matching_criteria['iot'] = intersection_over_true\n", - "\n", - "\n", - "def intersection_over_pred(overlap):\n", - " _check_label_array(overlap,'overlap')\n", - " if np.sum(overlap) == 0:\n", - " return overlap\n", - " n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)\n", - " return overlap / n_pixels_pred\n", - "\n", - "matching_criteria['iop'] = intersection_over_pred\n", - "\n", - "\n", - "def precision(tp,fp,fn):\n", - " return tp/(tp+fp) if tp > 0 else 0\n", - "def recall(tp,fp,fn):\n", - " return tp/(tp+fn) if tp > 0 else 0\n", - "def accuracy(tp,fp,fn):\n", - " # also known as \"average precision\" (?)\n", - " # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation\n", - " return tp/(tp+fp+fn) if tp > 0 else 0\n", - "def f1(tp,fp,fn):\n", - " # also known as \"dice coefficient\"\n", - " return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0\n", - "\n", - "\n", - "def _safe_divide(x,y):\n", - " return x/y if y>0 else 0.0\n", - "\n", - "def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):\n", - " \"\"\"Calculate detection/instance segmentation metrics between ground truth and predicted label images.\n", - " Currently, the following metrics are implemented:\n", - " 'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'\n", - " Corresponding objects of y_true and y_pred are counted as true positives (tp), false positives (fp), and false negatives (fn)\n", - 
" whether their intersection over union (IoU) >= thresh (for criterion='iou', which can be changed)\n", - " * mean_matched_score is the mean IoUs of matched true positives\n", - " * mean_true_score is the mean IoUs of matched true positives but normalized by the total number of GT objects\n", - " * panoptic_quality defined as in Eq. 1 of Kirillov et al. \"Panoptic Segmentation\", CVPR 2019\n", - " Parameters\n", - " ----------\n", - " y_true: ndarray\n", - " ground truth label image (integer valued)\n", - " predicted label image (integer valued)\n", - " thresh: float\n", - " threshold for matching criterion (default 0.5)\n", - " criterion: string\n", - " matching criterion (default IoU)\n", - " report_matches: bool\n", - " if True, additionally calculate matched_pairs and matched_scores (note, that this returns even gt-pred pairs whose scores are below 'thresh')\n", - " Returns\n", - " -------\n", - " Matching object with different metrics as attributes\n", - " Examples\n", - " --------\n", - " >>> y_true = np.zeros((100,100), np.uint16)\n", - " >>> y_true[10:20,10:20] = 1\n", - " >>> y_pred = np.roll(y_true,5,axis = 0)\n", - " >>> stats = matching(y_true, y_pred)\n", - " >>> print(stats)\n", - " Matching(criterion='iou', thresh=0.5, fp=1, tp=0, fn=1, precision=0, recall=0, accuracy=0, f1=0, n_true=1, n_pred=1, mean_true_score=0.0, mean_matched_score=0.0, panoptic_quality=0.0)\n", - " \"\"\"\n", - " _check_label_array(y_true,'y_true')\n", - " _check_label_array(y_pred,'y_pred')\n", - " y_true.shape == y_pred.shape or _raise(ValueError(\"y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes\".format(y_true=y_true, y_pred=y_pred)))\n", - " criterion in matching_criteria or _raise(ValueError(\"Matching criterion '%s' not supported.\" % criterion))\n", - " if thresh is None: thresh = 0\n", - " thresh = float(thresh) if np.isscalar(thresh) else map(float,thresh)\n", - "\n", - " y_true, _, map_rev_true = relabel_sequential(y_true)\n", - " y_pred, _, map_rev_pred = relabel_sequential(y_pred)\n", - "\n", - " overlap = label_overlap(y_true, y_pred, check=False)\n", - " scores = matching_criteria[criterion](overlap)\n", - " assert 0 <= np.min(scores) <= np.max(scores) <= 1\n", - "\n", - " # ignoring background\n", - " scores = scores[1:,1:]\n", - " n_true, n_pred = scores.shape\n", - " n_matched = min(n_true, n_pred)\n", - "\n", - " def _single(thr):\n", - " not_trivial = n_matched > 0 and np.any(scores >= thr)\n", - " if not_trivial:\n", - " # compute optimal matching with scores as tie-breaker\n", - " costs = -(scores >= thr).astype(float) - scores / (2*n_matched)\n", - " true_ind, pred_ind = linear_sum_assignment(costs)\n", - " assert n_matched == len(true_ind) == len(pred_ind)\n", - " match_ok = scores[true_ind,pred_ind] >= thr\n", - " tp = np.count_nonzero(match_ok)\n", - " else:\n", - " tp = 0\n", - " fp = n_pred - tp\n", - " fn = n_true - tp\n", - " # assert tp+fp == n_pred\n", - " # assert tp+fn == n_true\n", - "\n", - " # the score sum over all matched objects (tp)\n", - " sum_matched_score = np.sum(scores[true_ind,pred_ind][match_ok]) if not_trivial else 0.0\n", - "\n", - " # the score average over all matched objects (tp)\n", - " mean_matched_score = _safe_divide(sum_matched_score, tp)\n", - " # the score average over all gt/true objects\n", - " mean_true_score = _safe_divide(sum_matched_score, n_true)\n", - " panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)\n", - "\n", - " stats_dict = dict (\n", - " criterion = criterion,\n", - " thresh = thr,\n", - " 
fp = fp,\n", - " tp = tp,\n", - " fn = fn,\n", - " precision = precision(tp,fp,fn),\n", - " recall = recall(tp,fp,fn),\n", - " accuracy = accuracy(tp,fp,fn),\n", - " f1 = f1(tp,fp,fn),\n", - " n_true = n_true,\n", - " n_pred = n_pred,\n", - " mean_true_score = mean_true_score,\n", - " mean_matched_score = mean_matched_score,\n", - " panoptic_quality = panoptic_quality,\n", - " )\n", - " if bool(report_matches):\n", - " if not_trivial:\n", - " stats_dict.update (\n", - " # int() to be json serializable\n", - " matched_pairs = tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)),\n", - " matched_scores = tuple(scores[true_ind,pred_ind]),\n", - " matched_tps = tuple(map(int,np.flatnonzero(match_ok))),\n", - " )\n", - " else:\n", - " stats_dict.update (\n", - " matched_pairs = (),\n", - " matched_scores = (),\n", - " matched_tps = (),\n", - " )\n", - " return namedtuple('Matching',stats_dict.keys())(*stats_dict.values())\n", - "\n", - " return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh))\n", - "\n", - "\n", - "\n", - "def matching_dataset(y_true, y_pred, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):\n", - " \"\"\"matching metrics for list of images, see `stardist.matching.matching`\n", - " \"\"\"\n", - " len(y_true) == len(y_pred) or _raise(ValueError(\"y_true and y_pred must have the same length.\"))\n", - " return matching_dataset_lazy (\n", - " tuple(zip(y_true,y_pred)), thresh=thresh, criterion=criterion, by_image=by_image, show_progress=show_progress, parallel=parallel,\n", - " )\n", - "\n", - "\n", - "\n", - "def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):\n", - "\n", - " expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'))\n", - "\n", - " single_thresh = False\n", - " if np.isscalar(thresh):\n", - " single_thresh = True\n", - " thresh = (thresh,)\n", - "\n", - " tqdm_kwargs = {}\n", - " tqdm_kwargs['disable'] = not bool(show_progress)\n", - " if int(show_progress) > 1:\n", - " tqdm_kwargs['total'] = int(show_progress)\n", - "\n", - " # compute matching stats for every pair of label images\n", - " if parallel:\n", - " from concurrent.futures import ThreadPoolExecutor\n", - " fn = lambda pair: matching(*pair, thresh=thresh, criterion=criterion, report_matches=False)\n", - " with ThreadPoolExecutor() as pool:\n", - " stats_all = tuple(pool.map(fn, tqdm(y_gen,**tqdm_kwargs)))\n", - " else:\n", - " stats_all = tuple (\n", - " matching(y_t, y_p, thresh=thresh, criterion=criterion, report_matches=False)\n", - " for y_t,y_p in tqdm(y_gen,**tqdm_kwargs)\n", - " )\n", - "\n", - " # accumulate results over all images for each threshold separately\n", - " n_images, n_threshs = len(stats_all), len(thresh)\n", - " accumulate = [{} for _ in range(n_threshs)]\n", - " for stats in stats_all:\n", - " for i,s in enumerate(stats):\n", - " acc = accumulate[i]\n", - " for k,v in s._asdict().items():\n", - " if k == 'mean_true_score' and not bool(by_image):\n", - " # convert mean_true_score to \"sum_matched_score\"\n", - " acc[k] = acc.setdefault(k,0) + v * s.n_true\n", - " else:\n", - " try:\n", - " acc[k] = acc.setdefault(k,0) + v\n", - " except TypeError:\n", - " pass\n", - "\n", - " # normalize/compute 'precision', 'recall', 'accuracy', 'f1'\n", - " for thr,acc in zip(thresh,accumulate):\n", - " set(acc.keys()) == 
expected_keys or _raise(ValueError(\"unexpected keys\"))\n", - " acc['criterion'] = criterion\n", - " acc['thresh'] = thr\n", - " acc['by_image'] = bool(by_image)\n", - " if bool(by_image):\n", - " for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):\n", - " acc[k] /= n_images\n", - " else:\n", - " tp, fp, fn, n_true = acc['tp'], acc['fp'], acc['fn'], acc['n_true']\n", - " sum_matched_score = acc['mean_true_score']\n", - "\n", - " mean_matched_score = _safe_divide(sum_matched_score, tp)\n", - " mean_true_score = _safe_divide(sum_matched_score, n_true)\n", - " panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)\n", - "\n", - " acc.update(\n", - " precision = precision(tp,fp,fn),\n", - " recall = recall(tp,fp,fn),\n", - " accuracy = accuracy(tp,fp,fn),\n", - " f1 = f1(tp,fp,fn),\n", - " mean_true_score = mean_true_score,\n", - " mean_matched_score = mean_matched_score,\n", - " panoptic_quality = panoptic_quality,\n", - " )\n", - "\n", - " accumulate = tuple(namedtuple('DatasetMatching',acc.keys())(*acc.values()) for acc in accumulate)\n", - " return accumulate[0] if single_thresh else accumulate\n", - "\n", - "\n", - "\n", - "# copied from scikit-image master for now (remove when part of a release)\n", - "def relabel_sequential(label_field, offset=1):\n", - " \"\"\"Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}.\n", - " This function also returns the forward map (mapping the original labels to\n", - " the reduced labels) and the inverse map (mapping the reduced labels back\n", - " to the original ones).\n", - " Parameters\n", - " ----------\n", - " label_field : numpy array of int, arbitrary shape\n", - " An array of labels, which must be non-negative integers.\n", - " offset : int, optional\n", - " The return labels will start at `offset`, which should be\n", - " strictly positive.\n", - " Returns\n", - " -------\n", - " relabeled : numpy array of int, same shape as `label_field`\n", - " The input label field with labels mapped to\n", - " {offset, ..., number_of_labels + offset - 1}.\n", - " The data type will be the same as `label_field`, except when\n", - " offset + number_of_labels causes overflow of the current data type.\n", - " forward_map : numpy array of int, shape ``(label_field.max() + 1,)``\n", - " The map from the original label space to the returned label\n", - " space. Can be used to re-apply the same mapping. See examples\n", - " for usage. The data type will be the same as `relabeled`.\n", - " inverse_map : 1D numpy array of int, of length offset + number of labels\n", - " The map from the new label space to the original space. This\n", - " can be used to reconstruct the original label field from the\n", - " relabeled one. The data type will be the same as `relabeled`.\n", - " Notes\n", - " -----\n", - " The label 0 is assumed to denote the background and is never remapped.\n", - " The forward map can be extremely big for some inputs, since its\n", - " length is given by the maximum of the label field. 
However, in most\n", - " situations, ``label_field.max()`` is much smaller than\n", - " ``label_field.size``, and in these cases the forward map is\n", - " guaranteed to be smaller than either the input or output images.\n", - " Examples\n", - " --------\n", - " >>> from skimage.segmentation import relabel_sequential\n", - " >>> label_field = np.array([1, 1, 5, 5, 8, 99, 42])\n", - " >>> relab, fw, inv = relabel_sequential(label_field)\n", - " >>> relab\n", - " array([1, 1, 2, 2, 3, 5, 4])\n", - " >>> fw\n", - " array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5])\n", - " >>> inv\n", - " array([ 0, 1, 5, 8, 42, 99])\n", - " >>> (fw[label_field] == relab).all()\n", - " True\n", - " >>> (inv[relab] == label_field).all()\n", - " True\n", - " >>> relab, fw, inv = relabel_sequential(label_field, offset=5)\n", - " >>> relab\n", - " array([5, 5, 6, 6, 7, 9, 8])\n", - " \"\"\"\n", - " offset = int(offset)\n", - " if offset <= 0:\n", - " raise ValueError(\"Offset must be strictly positive.\")\n", - " if np.min(label_field) < 0:\n", - " raise ValueError(\"Cannot relabel array that contains negative values.\")\n", - " max_label = int(label_field.max()) # Ensure max_label is an integer\n", - " if not np.issubdtype(label_field.dtype, np.integer):\n", - " new_type = np.min_scalar_type(max_label)\n", - " label_field = label_field.astype(new_type)\n", - " labels = np.unique(label_field)\n", - " labels0 = labels[labels != 0]\n", - " new_max_label = offset - 1 + len(labels0)\n", - " new_labels0 = np.arange(offset, new_max_label + 1)\n", - " output_type = label_field.dtype\n", - " required_type = np.min_scalar_type(new_max_label)\n", - " if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize:\n", - " output_type = required_type\n", - " forward_map = np.zeros(max_label + 1, dtype=output_type)\n", - " forward_map[labels0] = new_labels0\n", - " inverse_map = np.zeros(new_max_label + 1, dtype=output_type)\n", - " inverse_map[offset:] = labels0\n", - " relabeled = forward_map[label_field]\n", - " return relabeled, forward_map, inverse_map\n", - "\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Tbv6DpxZjVN3" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. 
Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "jtSv-B0AjX8j" - }, - "source": [ - "#@markdown ###Not implemented yet" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "s2VXDuiOF7r4" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder! The result for one of the images will also be displayed.\n", - "\n", - "The **Intersection over Union** (IoU) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n", - "\n", - "Here, the IoU is both calculated over the whole image and on a per-object basis. The value displayed below is the IoU value calculated over the entire image. The IoU value calculated on a per-object basis is used to calculate the other metrics displayed.\n", - "\n", - "“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. \n", - "\n", - "When a segmented object has an IoU value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. The number of “**true positives**” is available in the table below. The number of “false positives” is then defined as “**false positive**” = “n_pred” - “true positive”. The number of “false negatives” is defined as “false negative” = “n_true” - “true positive”.\n", - "\n", - "The mean_matched_score is the mean IoU of matched true positives. The mean_true_score is the mean IoU of matched true positives but normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 2019](https://arxiv.org/abs/1801.00868).\n", - "\n", - "For more information about the other metrics displayed, please consult the SI of the paper describing ZeroCostDL4Mic.\n", - "\n", - " The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n", - "\n", - "**`model_choice`:** Choose the model to use to make predictions. This model needs to be a Cellpose model. 
You can also use the pretrained models already available in cellpose: \n", - "\n", - "- The cytoplasm model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is an optional nuclear channel.\n", - "- The cytoplasm2 model is an updated cytoplasm model trained with user-submitted images.\n", - "\n", - "- The nuclear model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is always set to an array of zeros.\n", - "\n", - "**`Channel_to_segment`:** Choose the channel to segment. If using single-channel grayscale images, choose \"Grayscale\".\n", - "\n", - "**`Nuclear_channel`:** If you are using a model that segments the \"cytoplasm\", you can use a nuclear channel to aid the segmentation. \n", - "\n", - "**`Object_diameter`:** Indicate the diameter of the objects (cells or nuclei) you want to segment (in pixels). If you input \"0\", this parameter will be estimated automatically for each of your images.\n", - "\n", - "**`Flow_threshold`:** This parameter controls the maximum allowed error of the flows for each mask. Increase this threshold if cellpose is not returning as many masks as you'd expect. Similarly, decrease this threshold if cellpose is returning too many ill-shaped masks. **Default value: 0.4**\n", - "\n", - "**`Cell_probability_threshold`:** The pixels greater than the Cell_probability_threshold are used to run dynamics and determine masks. Decrease this threshold if cellpose is not returning as many masks as you'd expect. Similarly, increase this threshold if cellpose is returning too many masks, particularly from dim areas. **Default value: 0.0**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "BUrTuonhEH5J", - "cellView": "form" - }, - "source": [ - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_QC_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "Channel_to_segment= \"Grayscale\" #@param [\"Grayscale\", \"Blue\", \"Green\", \"Red\"]\n", - "\n", - "# @markdown ###If you chose the model \"cytoplasm\", indicate if you also have a nuclear channel that can be used to aid the segmentation.\n", - "\n", - "Nuclear_channel= \"None\" #@param [\"None\", \"Blue\", \"Green\", \"Red\"]\n", - "\n", - "#@markdown ### Segmentation parameters:\n", - "Object_diameter = 0#@param {type:\"number\"}\n", - "\n", - "Flow_threshold = 0.4 #@param {type:\"slider\", min:0.1, max:1.1, step:0.1}\n", - "Cell_probability_threshold=0 #@param {type:\"slider\", min:-6, max:6, step:1}\n", - "\n", - "if Object_diameter == 0:\n", - " Object_diameter = None\n", - " print(\"The cell size will be estimated automatically for each image\")\n", - "\n", - "\n", - "# Find the number of channels in the input image\n", - "\n", - "random_choice = random.choice(os.listdir(Source_QC_folder))\n", - "x = io.imread(Source_QC_folder+\"/\"+random_choice)\n", - "n_channel = 1 if x.ndim == 2 else x.shape[-1]\n", - "\n", - "if Channel_to_segment == \"Grayscale\":\n", - " segment_channel = 0\n", - "\n", - " if not n_channel == 1:\n", - " print(bcolors.WARNING +\"!! 
WARNING: your image has more than one channel, choose which channel you want to use for QC !!\")\n", - "\n", - "if Channel_to_segment == \"Blue\":\n", - " segment_channel = 3\n", - "\n", - "if Channel_to_segment == \"Green\":\n", - " segment_channel = 2\n", - "\n", - "if Channel_to_segment == \"Red\":\n", - " segment_channel = 1\n", - "\n", - "if Nuclear_channel == \"Blue\":\n", - " nuclear_channel = 3\n", - "\n", - "if Nuclear_channel == \"Green\":\n", - " nuclear_channel = 2\n", - "\n", - "if Nuclear_channel == \"Red\":\n", - " nuclear_channel = 1\n", - "\n", - "if Nuclear_channel == \"None\":\n", - " nuclear_channel = 0\n", - "\n", - "if QC_model == \"Cytoplasm\": \n", - " channels=[segment_channel,nuclear_channel]\n", - "\n", - "if QC_model == \"Cytoplasm2\": \n", - " channels=[segment_channel,nuclear_channel]\n", - " \n", - "if QC_model == \"Nuclei\":\n", - " channels=[segment_channel,0]\n", - "\n", - "if QC_model == \"Own_model\":\n", - " channels=[segment_channel,nuclear_channel]\n", - "\n", - "#Create a quality control Folder and check if the folder already exist\n", - "if os.path.exists(QC_model_folder+\"/Quality Control\") == False:\n", - " os.makedirs(QC_model_folder+\"/Quality Control\")\n", - "\n", - "if os.path.exists(QC_model_folder+\"/Quality Control/Prediction\"):\n", - " shutil.rmtree(QC_model_folder+\"/Quality Control/Prediction\")\n", - "os.makedirs(QC_model_folder+\"/Quality Control/Prediction\")\n", - "\n", - "# Here we need to make predictions\n", - "\n", - "for name in os.listdir(Source_QC_folder):\n", - " \n", - " print(\"Performing prediction on: \"+name)\n", - " image = io.imread(Source_QC_folder+\"/\"+name) \n", - "\n", - " short_name = os.path.splitext(name)\n", - " \n", - " if QC_model == \"Own_model\":\n", - " masks, flows, styles = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n", - " else:\n", - " masks, flows, styles, diams = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n", - " \n", - " os.chdir(QC_model_folder+\"/Quality Control/Prediction\")\n", - " imsave(str(short_name[0])+\".tif\", masks, compress=ZIP_DEFLATED) \n", - " \n", - "# Here we start testing the differences between GT and predicted masks\n", - "\n", - "with open(QC_model_folder+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file, delimiter=\",\")\n", - " writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\", \"false positive\", \"true positive\", \"false negative\", \"precision\", \"recall\", \"accuracy\", \"f1 score\", \"n_true\", \"n_pred\", \"mean_true_score\", \"mean_matched_score\", \"panoptic_quality\"]) \n", - "\n", - "# define the images\n", - "\n", - " for n in os.listdir(Source_QC_folder):\n", - " \n", - " if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n", - " print('Running QC on: '+n)\n", - " test_input = io.imread(os.path.join(Source_QC_folder,n))\n", - " test_prediction = io.imread(os.path.join(QC_model_folder+\"/Quality Control/Prediction\",n))\n", - " test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n", - "\n", - " # Calculate the matching (with IoU threshold `thresh`) and all metrics\n", - "\n", - " stats = matching(test_ground_truth_image, test_prediction, thresh=0.5)\n", - " \n", - "\n", - " #Convert pixel values to 0 or 255\n", - " test_prediction_0_to_255 = test_prediction\n", - " test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n", - "\n", - " #Convert pixel values to 0 or 255\n", - " test_ground_truth_0_to_255 = test_ground_truth_image\n", - " test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n", - "\n", - "\n", - " # Intersection over Union metric\n", - "\n", - " intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n", - " union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - " writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])\n", - "\n", - "from tabulate import tabulate\n", - "\n", - "df = pd.read_csv (QC_model_folder+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\")\n", - "print(tabulate(df, headers='keys', tablefmt='psql'))\n", - "\n", - "\n", - "from astropy.visualization import simple_norm\n", - "\n", - "# ------------- For display ------------\n", - "print('--------------------------------------------------------------')\n", - "@interact\n", - "def show_QC_results(file = os.listdir(Source_QC_folder)):\n", - " \n", - "\n", - " plt.figure(figsize=(25,5))\n", - " if n_channel > 1:\n", - " source_image = io.imread(os.path.join(Source_QC_folder, file))\n", - " if n_channel == 1:\n", - " source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n", - "\n", - " target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)\n", - " prediction = io.imread(QC_model_folder+\"/Quality Control/Prediction/\"+file, as_gray = True)\n", - "\n", - " stats = matching(prediction, target_image, thresh=0.5)\n", - "\n", - " target_image_mask = np.empty_like(target_image)\n", - " target_image_mask[target_image > 0] = 255\n", - " target_image_mask[target_image == 0] = 0\n", - " \n", - " prediction_mask = np.empty_like(prediction)\n", - " prediction_mask[prediction > 0] = 255\n", - " prediction_mask[prediction == 0] = 0\n", - "\n", - " intersection = np.logical_and(target_image_mask, prediction_mask)\n", - " union = np.logical_or(target_image_mask, prediction_mask)\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - "\n", - " norm = simple_norm(source_image, percent = 99)\n", - "\n", - " #Input\n", - " plt.subplot(1,4,1)\n", - " plt.axis('off')\n", - " if n_channel > 1:\n", - " 
plt.imshow(source_image)\n", - " if n_channel == 1:\n", - " plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n", - " plt.title('Input')\n", - "\n", - " #Ground-truth\n", - " plt.subplot(1,4,2)\n", - " plt.axis('off')\n", - " plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n", - " plt.title('Ground Truth')\n", - "\n", - " #Prediction\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n", - " plt.title('Prediction')\n", - "\n", - " #Overlay\n", - " plt.subplot(1,4,4)\n", - " plt.axis('off')\n", - " plt.imshow(target_image_mask, cmap='Greens')\n", - " plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n", - " plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));\n", - " plt.savefig(QC_model_folder+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "\n", - "#qc_pdf_export()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Io62PUMLagFS" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "E29LWfWpjkZU" - }, - "source": [ - "\n", - "\n", - "## **6.1 Generate prediction(s) from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the model's name and path to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n", - "\n", - "**`Data_folder`:** This folder should contain the images that you want to predict using the network that you will train.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output.\n", - "\n", - "**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n", - "\n", - "**`model_choice`:** Choose the model to use to make predictions. This model needs to be a Cellpose model. You can also use the pretrained models already available in cellpose: \n", - "\n", - "- The cytoplasm model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is an optional nuclear channel.\n", - "- The cytoplasm2 model is an updated cytoplasm model trained with user-submitted images.\n", - "\n", - "- The nuclear model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is always set to an array of zeros.\n", - "\n", - "**`Channel_to_segment`:** Choose the channel to segment. If using single-channel grayscale images, choose \"Grayscale\".\n", - "\n", - "**`Nuclear_channel`:** If you are using a model that segment the \"cytoplasm\", you can use a nuclear channel to aid the segmentation. \n", - "\n", - "**`Object_diameter`:** Indicate the diameter of the objects (cells or Nuclei) you want to segment (in pixel). 
If you input \"0\", this parameter will be estimated automatically for each of your images.\n", - "\n", - "**`Flow_threshold`:** This parameter controls the maximum allowed error of the flows for each mask. Increase this threshold if cellpose is not returning as many masks as you'd expect. Similarly, decrease this threshold if cellpose is returning too many ill-shaped masks. **Default value: 0.4**\n", - "\n", - "**`Cell_probability_threshold`:** The pixels greater than the Cell_probability_threshold are used to run dynamics and determine masks. Decrease this threshold if cellpose is not returning as many masks as you'd expect. Similarly, increase this threshold if cellpose is returning too many masks, particularly from dim areas. **Default value: 0.0**\n", - "\n", - "**IMPORTANT:** One example result will be displayed first so that you can assess the quality of the prediction and change your settings accordingly. Once the most suitable settings have been chosen, press on the yellow button \"process your images\".\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "mfgvhMk2xid9", - "cellView": "form" - }, - "source": [ - "\n", - "#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n", - "\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Result_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ###Are your data single images or stacks?\n", - "\n", - "Data_type = \"Single_Images\" #@param [\"Single_Images\", \"Stacks\"]\n", - "\n", - "#@markdown ###What model do you want to use?\n", - "\n", - "model_choice = \"Cytoplasm2\" #@param [\"Cytoplasm\",\"Cytoplasm2\", \"Nuclei\", \"Own_model\"]\n", - "\n", - "#@markdown ####If using your own model, please provide the path to the model (not the folder):\n", - "\n", - "Prediction_model = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ### What channel do you want to segment?\n", - "\n", - "Channel_to_segment= \"Grayscale\" #@param [\"Grayscale\", \"Blue\", \"Green\", \"Red\"]\n", - "\n", - "# @markdown ###If you chose the model \"cytoplasm\" indicate if you also have a nuclear channel that can be used to aid the segmentation.\n", - "\n", - "Nuclear_channel= \"None\" #@param [\"None\", \"Blue\", \"Green\", \"Red\"]\n", - "\n", - "#@markdown ### Segmentation parameters:\n", - "Object_diameter = 0#@param {type:\"number\"}\n", - "\n", - "Flow_threshold = 0.4 #@param {type:\"slider\", min:0.1, max:1.1, step:0.1}\n", - "Cell_probability_threshold=0 #@param {type:\"slider\", min:-6, max:6, step:1}\n", - "\n", - "# Find the number of channel in the input image\n", - "\n", - "random_choice = random.choice(os.listdir(Data_folder))\n", - "x = io.imread(Data_folder+\"/\"+random_choice)\n", - "n_channel = 1 if x.ndim == 2 else x.shape[-1]\n", - "\n", - "if Channel_to_segment == \"Grayscale\":\n", - " segment_channel = 0\n", - "\n", - " if Data_type == \"Single_Images\":\n", - " if not n_channel == 1:\n", - " print(bcolors.WARNING +\"!! 
WARNING: your image has more than one channel, choose which channel you want to use for your predictions !!\")\n", - "\n", - "if Channel_to_segment == \"Blue\":\n", - " segment_channel = 3\n", - "\n", - "if Channel_to_segment == \"Green\":\n", - " segment_channel = 2\n", - "\n", - "if Channel_to_segment == \"Red\":\n", - " segment_channel = 1\n", - "\n", - "if Nuclear_channel == \"Blue\":\n", - " nuclear_channel = 3\n", - "\n", - "if Nuclear_channel == \"Green\":\n", - " nuclear_channel = 2\n", - "\n", - "if Nuclear_channel == \"Red\":\n", - " nuclear_channel = 1\n", - "\n", - "if Nuclear_channel == \"None\":\n", - " nuclear_channel = 0\n", - "\n", - "if model_choice == \"Cytoplasm\": \n", - " channels=[segment_channel,nuclear_channel]\n", - " model = models.Cellpose(gpu=True, model_type=\"cyto\")\n", - " print(\"Cytoplasm model enabled\")\n", - "\n", - "if model_choice == \"Cytoplasm2\": \n", - " channels=[segment_channel,nuclear_channel]\n", - " model = models.Cellpose(gpu=True, model_type=\"cyto2\")\n", - " print(\"Cytoplasm model enabled\")\n", - " \n", - "if model_choice == \"Nuclei\":\n", - " channels=[segment_channel,0]\n", - " model = models.Cellpose(gpu=True, model_type=\"nuclei\")\n", - " print(\"Nuclei model enabled\")\n", - "\n", - "if model_choice == \"Own_model\":\n", - " channels=[segment_channel,nuclear_channel]\n", - " model = models.CellposeModel(gpu=True, pretrained_model=Prediction_model, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n", - "\n", - " print(\"Own model enabled\")\n", - "\n", - "if Object_diameter is 0:\n", - " Object_diameter = None\n", - " print(\"The cell size will be estimated automatically for each image\")\n", - "\n", - "if Data_type == \"Single_Images\" :\n", - "\n", - " print('--------------------------------------------------------------')\n", - " @interact\n", - " def preview_results(file = os.listdir(Data_folder)):\n", - " source_image = io.imread(os.path.join(Data_folder, file))\n", - " \n", - " if model_choice == \"Own_model\":\n", - " masks, flows, styles = model.eval(source_image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n", - " else:\n", - " masks, flows, styles, diams = model.eval(source_image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n", - " \n", - " flowi = flows[0]\n", - " fig = plt.figure(figsize=(20,10))\n", - " plot.show_segmentation(fig, source_image, masks, flowi, channels=channels)\n", - " plt.tight_layout()\n", - " plt.show()\n", - "\n", - "\n", - " def batch_process():\n", - " print(\"Your images are now beeing processed\")\n", - "\n", - " for name in os.listdir(Data_folder):\n", - " print(\"Performing prediction on: \"+name)\n", - " image = io.imread(Data_folder+\"/\"+name)\n", - " short_name = os.path.splitext(name)\n", - " \n", - " if model_choice == \"Own_model\":\n", - " masks, flows, styles = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n", - " else:\n", - " masks, flows, styles, diams = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n", - " \n", - " os.chdir(Result_folder)\n", - " imsave(str(short_name[0])+\"_mask.tif\", masks, compress=ZIP_DEFLATED)\n", - "\n", - " im = interact_manual(batch_process)\n", - " 
im.widget.children[0].description = 'Process your images'\n", - " im.widget.children[0].style.button_color = 'yellow'\n", - " display(im)\n", - "\n", - "if Data_type == \"Stacks\" :\n", - " print(\"Stacks are now beeing predicted\")\n", - " \n", - " print('--------------------------------------------------------------')\n", - " @interact\n", - " def preview_results_stacks(file = os.listdir(Data_folder)):\n", - " timelapse = imread(Data_folder+\"/\"+file)\n", - "\n", - " if model_choice == \"Own_model\":\n", - " masks, flows, styles = model.eval(timelapse[0], diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n", - " else:\n", - " masks, flows, styles, diams = model.eval(timelapse[0], diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n", - " \n", - " flowi = flows[0]\n", - " fig = plt.figure(figsize=(20,10))\n", - " plot.show_segmentation(fig, timelapse[0], masks, flowi, channels=channels)\n", - " plt.tight_layout()\n", - " plt.show()\n", - "\n", - " def batch_process_stack():\n", - " print(\"Your images are now beeing processed\") \n", - " for image in os.listdir(Data_folder):\n", - " print(\"Performing prediction on: \"+image)\n", - " timelapse = imread(Data_folder+\"/\"+image)\n", - " short_name = os.path.splitext(image)\n", - " n_timepoint = timelapse.shape[0]\n", - " prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n", - " \n", - " for t in range(n_timepoint):\n", - " print(\"Frame number: \"+str(t))\n", - " img_t = timelapse[t]\n", - "\n", - " if model_choice == \"Own_model\":\n", - " masks, flows, styles = model.eval(img_t, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n", - " else:\n", - " masks, flows, styles, diams = model.eval(img_t, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n", - " \n", - " \n", - " prediction_stack[t] = masks\n", - " \n", - " prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n", - " os.chdir(Result_folder)\n", - " imsave(str(short_name[0])+\".tif\", prediction_stack_32, compress=ZIP_DEFLATED)\n", - " \n", - " im = interact_manual(batch_process_stack)\n", - " im.widget.children[0].description = 'Process your images'\n", - " im.widget.children[0].style.button_color = 'yellow'\n", - " display(im) \n", - " \n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.2. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." 
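If you later want to re-use the downloaded model outside of this notebook, a minimal sketch along the lines below can work. It is not part of the original notebook: it simply mirrors the cellpose API calls used in section 6.1, it assumes cellpose and tifffile are installed locally, and all file paths are placeholders to adapt.

```python
# Minimal sketch (not from the original notebook): re-using a downloaded Cellpose model locally.
# All paths are placeholders; the eval() arguments mirror those used in section 6.1.
from cellpose import models
from tifffile import imread, imwrite

pretrained = "/path/to/model_name/cellpose_model_file"  # placeholder: model file saved in section 4
model = models.CellposeModel(gpu=False, pretrained_model=pretrained)

image = imread("/path/to/image.tif")  # placeholder: single 2D image to segment
masks, flows, styles = model.eval(image,
                                  diameter=None,        # use the model's default object diameter
                                  flow_threshold=0.4,
                                  cellprob_threshold=0.0,
                                  channels=[0, 0])      # grayscale input, no nuclear channel
imwrite("/path/to/image_mask.tif", masks)               # label image, one integer per object
```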
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u4pcBe8Z3T2J" - }, - "source": [ - "#**Thank you for using Cellpose 2D!**" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Cellpose_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hMjEc-Ex7j-jeYGclaPw2x3OgbkeC6Bl","timestamp":1610626439596},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"cells":[{"cell_type":"markdown","metadata":{"id":"0YVrnjRozAGg"},"source":["This is notebook is in beta, expect bugs and missing features compared to other ZeroCostDL4Mic notebooks\n","\n","- Training now uses TORCH. \n","- Currently missing features include: \n"," - The training and validation curves are not saved or visualised\n"," "]},{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **Cellpose (2D)**\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"pwLsIXtEw3Kx"},"source":["**Cellpose 2D** is a deep-learning method that can be used to segment cell and/or nuclei from bioimages and was first published by [Stringer *et al.* in 2020, in Nature Method](https://www.nature.com/articles/s41592-020-01018-x). \n","\n"," **This particular notebook enables nuclei segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist or U-Net 3D notebooks instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cellpose: a generalist algorithm for cellular segmentation** from Stringer *et al.*, Nature Methods, 2020. (https://www.nature.com/articles/s41592-020-01018-x)\n","\n","**The Original code** is freely available in GitHub:\n","/~https://github.com/MouseLand/cellpose\n","\n","**Please also cite this original paper when using or developing this notebook.**\n","\n","**This notebook was also inspired by the one created by @pr4deepr** which is available here:\n","https://colab.research.google.com/github/MouseLand/cellpose/blob/master/notebooks/Cellpose_2D_v0_1.ipynb\n"]},{"cell_type":"markdown","metadata":{"id":"C5oYf0Q5yXrl"},"source":["#**0. Before getting started**\n","---\n"," For Cellpose to train, **it needs to have access to a paired training dataset made of images and their corresponding masks (label images)**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. 
The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n"," **Use 8/16 bit png or Tiff images**.\n","\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Label images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a pretrained model you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. 
Install Cellpose and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'Cellpose'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory\n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install Cellpose and dependencies\n","\n","#Libraries contains information of certain topics. \n","#For example the tifffile library contains information on how to handle tif-files.\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install cellpose \n","#!pip install mxnet-cu101 \n","!pip install wget\n","!pip install memory_profiler\n","!pip install fpdf\n","%load_ext memory_profiler\n","\n","# ------- Variable specific to Cellpose -------\n","\n","from urllib.parse import urlparse\n","%matplotlib inline\n","from cellpose import models\n","use_GPU = models.use_gpu()\n","\n","#import mxnet as mx\n","\n","from skimage.util import img_as_ubyte\n","import cv2\n","from cellpose import plot\n","from ipywidgets import interact, interact_manual\n","from zipfile import ZIP_DEFLATED\n","\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import 
img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","!pip freeze > requirements.txt\n","\n","#Create a pdf document with training summary\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," # save FPDF() class into a \n"," # variable pdf \n"," #from datetime import datetime\n","\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," \n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for 
'+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," # print(text)\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)\n","\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>batch_size</td><td>{1}</td></tr>
<tr><td>percentage_validation</td><td>{2}</td></tr>
<tr><td>initial_learning_rate</td><td>{3}</td></tr>
\n"," \"\"\".format(number_of_epochs,batch_size,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Cellpose2D.png').shape\n"," pdf.image('/content/TrainingDataExample_Cellpose2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Cellpose: Stringer, Carsen, et al. \"Cellpose: a generalist algorithm for cellular segmentation.\" Nature Methods 18, pages100-106(2021).'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n"," # pdf.output(Saving_path+'/train_folder/models/'+model_name+\"_training_report.pdf\")\n","\n"," \n","\n","#Make a pdf summary of the QC results\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Cellpose 2D'\n"," #model_name = os.path.basename(full_QC_model_path)\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," 
pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.', align='L')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
\"\"\"\n","\n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **2. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","#%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["\n","#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"2UfUWjI_askO"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"ZyMxrSWvavVL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"5MlTyQVXXvDx"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of cells) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 100 epochs, but a full training should run for up to 500-1000 epochs. Evaluate the performance after training (see 5.). **Default value: 500**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 8**\n","\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. 
**Default value: 10** \n","\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"2HkNZ16BdfJv","cellView":"form"},"source":["#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters:\n","#@markdown Number of epochs:\n","number_of_epochs = 500#@param {type:\"number\"}\n","\n","Channel_to_use_for_training = \"Grayscale\" #@param [\"Grayscale\", \"Blue\", \"Green\", \"Red\"]\n","\n","# @markdown ###If you have a secondary channel that can be used for training, for instance nuclei, choose it here:\n","\n","Second_training_channel= \"None\" #@param [\"None\", \"Blue\", \"Green\", \"Red\"]\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","batch_size = 8#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 8 \n"," initial_learning_rate = 0.0002\n"," percentage_validation = 10\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","# Here we enable the cyto pre-trained model by default (in case the cell is not ran)\n","model_to_load = \"cyto\"\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","norm = simple_norm(x, percent = 99)\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","# Find the number of channel in the input image\n","\n","n_channel = 1 if x.ndim == 2 else x.shape[-1]\n","\n","\n","# Here we match the channel to number\n","\n","if Channel_to_use_for_training == \"Grayscale\":\n"," Training_channel = 0\n","\n"," if not n_channel == 1:\n"," print(bcolors.WARNING +\"!! 
WARNING: your image has more than one channel, choose which channel you want to use for trainning !!\")\n","\n","if Channel_to_use_for_training == \"Blue\":\n"," Training_channel = 3\n","\n","if Channel_to_use_for_training == \"Green\":\n"," Training_channel = 2\n","\n","if Channel_to_use_for_training == \"Red\":\n"," Training_channel = 1\n","\n","\n","if Second_training_channel == \"Blue\":\n"," Second_training_channel = 3\n","\n","if Second_training_channel == \"Green\":\n"," Second_training_channel = 2\n","\n","if Second_training_channel == \"Red\":\n"," Second_training_channel = 1\n","\n","if Second_training_channel == \"None\":\n"," Second_training_channel = 0\n","\n","\n","if n_channel ==1:\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Training source')\n"," plt.axis('off');\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(y,cmap='nipy_spectral', interpolation='nearest')\n"," plt.title('Training target')\n"," plt.axis('off');\n","\n"," plt.savefig('/content/TrainingDataExample_Cellpose2D.png',bbox_inches='tight',pad_inches=0)\n","\n","else:\n","\n"," f=plt.figure(figsize=(20,10))\n"," plt.subplot(1,3,1)\n"," plt.imshow(x, interpolation='nearest')\n"," plt.title('Training source')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(x[:, :, int(Training_channel-1)],cmap='magma', interpolation='nearest')\n"," plt.title('Channel used for training')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(y,cmap='nipy_spectral', interpolation='nearest')\n"," plt.title('Training target')\n"," plt.axis('off');\n","\n"," plt.savefig('/content/TrainingDataExample_Cellpose2D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"qEg6ar0PhuDY"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"t6q9aqDUhxlw"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","By default, a x4 data augmentation is enabled in this notebook."]},{"cell_type":"code","metadata":{"id":"SblwpgmahfBl","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation enabled\")\n"," Multiply_dataset_by = 4\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"s2NC_-Tuc02W"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Cellpose model**. 
\n","\n"," You can also use the pretrained models already available in Cellpose: \n","\n","- The cytoplasm model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is an optional nuclear channel.\n","\n","- The cytoplasm2 model is an updated cytoplasm model trained with user-submitted images.\n","\n","- The nuclear model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is always set to an array of zeros.\n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n"]},{"cell_type":"code","metadata":{"id":"sLdgQM6Rc7vp","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","\n","Pretrained_model = \"Nuclei\" #@param [\"Cytoplasm\",\"Cytoplasm2\", \"Nuclei\", \"Own_model\"]\n","\n","#@markdown ###If using your own model, please provide the path to the model (not the folder):\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","\n","if Use_pretrained_model == True :\n","\n"," if Pretrained_model == \"Own_model\": \n","\n"," model_to_load = pretrained_model_path\n","\n"," print('The model '+ str(model_to_load) + \"will be used as a starting point\")\n","\n"," if Pretrained_model == \"Cytoplasm\":\n"," model_to_load = \"cyto\"\n"," print('The model Cytoplasm will be used as a starting point')\n","\n"," if Pretrained_model == \"Cytoplasm2\":\n"," model_to_load = \"cyto2\"\n"," print('The model Cytoplasm2 (cyto2) will be used as a starting point')\n","\n"," if Pretrained_model == \"Nuclei\":\n"," model_to_load = \"nuclei\"\n"," print('The model nuclei will be used as a starting point')\n","\n","else:\n"," model_to_load = None\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"qeYZ7PeValfs"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"code","metadata":{"id":"tsn8WV3Wl0sG","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","\n","# Here we check that the model destination folder is empty\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","os.makedirs(model_path+\"/\"+model_name)\n","\n","\n","#To use cellpose to work we need to organise the data in a way the network can understand\n","\n","# Here we count the number of files in the training target folder\n","Filelist = os.listdir(Training_target)\n","number_files = len(Filelist)\n","\n","# Here we count the number of file to use for validation\n","Image_for_validation = int((number_files)*(percentage_validation/100))\n","\n","Saving_path= \"/content/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","train_folder = Saving_path+\"/train_folder\"\n","os.makedirs(train_folder)\n","\n","test_folder = Saving_path+\"/test_folder\"\n","os.makedirs(test_folder)\n","\n","index = 0\n","\n","print('Copying training source data...')\n","for f in tqdm(os.listdir(Training_source)): \n"," short_name = os.path.splitext(f)\n","\n"," if index < Image_for_validation:\n"," shutil.copyfile(Training_source+\"/\"+f, test_folder+\"/\"+short_name[0]+\"_img.tif\")\n"," shutil.copyfile(Training_target+\"/\"+f, test_folder+\"/\"+short_name[0]+\"_masks.tif\") \n"," else:\n"," shutil.copyfile(Training_source+\"/\"+f, train_folder+\"/\"+short_name[0]+\"_img.tif\")\n"," shutil.copyfile(Training_target+\"/\"+f, train_folder+\"/\"+short_name[0]+\"_masks.tif\") \n"," index = index +1\n"," \n","print(\"Done\")\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jRorFe296LgI"},"source":["## **4.2. Start Training**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Training is currently done using Torch.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n"]},{"cell_type":"code","metadata":{"id":"YXUnd3awi6K3","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","if not Use_Data_augmentation:\n"," #!python -m cellpose --train --use_gpu --mxnet --fast_mode --dir \"$train_folder\" --test_dir \"$test_folder\" --pretrained_model $model_to_load --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter img --mask_filter masks\n"," !python -m cellpose --train --use_gpu --fast_mode --dir \"$train_folder\" --test_dir \"$test_folder\" --pretrained_model $model_to_load --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter img --mask_filter masks\n","\n","else:\n"," #!python -m cellpose --train --use_gpu --mxnet --dir \"$train_folder\" --test_dir \"$test_folder\" --pretrained_model $model_to_load --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter img --mask_filter masks\n"," !python -m cellpose --train --use_gpu --fast_mode --dir \"$train_folder\" --test_dir \"$test_folder\" --pretrained_model $model_to_load --chan $Training_channel --chan2 $Second_training_channel --n_epochs $number_of_epochs --learning_rate $initial_learning_rate --batch_size $batch_size --img_filter img --mask_filter masks\n","\n","\n","#Settings\n","# --check_mkl', action='store_true', help='check if mkl working'\n","\n","#'--mkldnn', action='store_true', help='for mxnet, force MXNET_SUBGRAPH_BACKEND = \"MKLDNN\"')\n","\n","#'--train', action='store_true', help='train network using images in dir')\n","#'--dir', required=False, help='folder containing data to run or train on')\n","# '--mxnet', action='store_true', help='use mxnet')\n","# '--img_filter', required=False, default=[], type=str, help='end string for images to run on')\n","# '--use_gpu', action='store_true', help='use gpu if mxnet with cuda installed')\n","# '--fast_mode', action='store_true', help=\"make code run faster by turning off 4 network averaging\")\n","# '--resample', action='store_true', help=\"run dynamics on full image (slower for images with large diameters)\")\n","# '--no_interp', action='store_true', help='do not interpolate when running dynamics (was default)')\n","# '--do_3D', action='store_true', help='process images as 3D stacks of images (nplanes x nchan x Ly x Lx')\n"," \n","# settings for training\n","# parser.add_argument('--train_size', action='store_true', help='train size network at end of training')\n","# parser.add_argument('--mask_filter', required=False, default='_masks', type=str, help='end string for masks to run on')\n","# parser.add_argument('--test_dir', required=False, default=[], type=str, help='folder containing test data (optional)')\n","# parser.add_argument('--learning_rate', required=False, default=0.2, type=float, help='learning rate')\n","# parser.add_argument('--n_epochs', required=False, default=500, type=int, help='number of epochs')\n","# parser.add_argument('--batch_size', required=False, default=8, type=int, help='batch size')\n","# parser.add_argument('--residual_on', required=False, default=1, type=int, help='use residual connections')\n","# parser.add_argument('--style_on', 
required=False, default=1, type=int, help='use style vector')\n","# parser.add_argument('--concatenation', required=False, dfault=0, type=int, help='concatenate downsampled layers with upsampled layers (off by default which means they are added)')\n","\n","\n","\n","#Here we copy the model to the result folder after training\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","destination = shutil.copytree(Saving_path+\"/train_folder/models\", model_path+\"/\"+model_name)\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","print(\"Your model is also available here: \"+str(model_path+\"/\"+model_name))\n","\n","\n","pdf_export(trained=True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"qvbm9EJGaXr9"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"IeiU6D2jGDh4","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, indicate which model you want to assess:\n","\n","QC_model = \"Own_model\" #@param [\"Cytoplasm\",\"Cytoplasm2\", \"Nuclei\", \"Own_model\"]\n","\n","#@markdown ###If using your own model, please provide the path to the model (not the folder):\n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###If using the \"Cytoplasm\" or \"Nuclei\" models, please indicate where you want to save the results:\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if Use_the_current_trained_model :\n","\n"," list_files = os.listdir(model_path+\"/\"+model_name)\n"," \n"," QC_model_path = model_path+\"/\"+model_name+\"/\"+list_files[0]\n"," QC_model = \"Own_model\"\n","\n"," #model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=False, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n"," model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n","\n"," QC_model_folder = os.path.dirname(QC_model_path)\n"," QC_model_name = os.path.basename(QC_model_folder)\n"," Saving_path = QC_model_folder\n","\n"," print(\"The \"+str(QC_model_name)+\" model will be evaluated\")\n","\n","if not Use_the_current_trained_model:\n","\n"," if QC_model == \"Cytoplasm\":\n"," model = models.Cellpose(gpu=True, model_type=\"cyto\")\n"," QC_model_folder = Saving_path\n"," QC_model_name = \"Cytoplasm\"\n","\n"," print('The model \"Cytoplasm\" will be evaluated')\n","\n"," if QC_model == \"Cytoplasm2\":\n"," model = models.Cellpose(gpu=True, model_type=\"cyto2\")\n"," QC_model_folder = Saving_path\n"," QC_model_name = \"Cytoplasm2\"\n","\n"," print('The model \"Cytoplasm\" will be evaluated')\n","\n"," if QC_model 
== \"Nuclei\":\n"," model = models.Cellpose(gpu=True, model_type=\"nuclei\")\n","\n"," QC_model_folder = Saving_path\n"," QC_model_name = \"Nuclei\"\n","\n"," print('The model \"Nuclei\" will be evaluated')\n"," \n"," if QC_model == \"Own_model\":\n","\n"," if os.path.exists(QC_model_path):\n"," model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n"," \n"," QC_model_folder = os.path.dirname(QC_model_path)\n"," Saving_path = QC_model_folder\n"," QC_model_name = os.path.basename(QC_model_folder)\n"," print(\"The \"+str(QC_model_name)+\" model will be evaluated\")\n"," \n"," else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we make the folder to save the resuslts if it does not exists\n","\n","if not Saving_path == \"\":\n"," if os.path.exists(QC_model_folder) == False:\n"," os.makedirs(QC_model_folder)\n","else:\n"," print(bcolors.WARNING+'!! WARNING: Indicate where you want to save the results')\n","\n","\n","# Here we load the def that perform the QC, code taken from StarDist /~https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py\n","\n","import numpy as np\n","from numba import jit\n","from tqdm import tqdm\n","from scipy.optimize import linear_sum_assignment\n","from collections import namedtuple\n","\n","\n","matching_criteria = dict()\n","\n","def label_are_sequential(y):\n"," \"\"\" returns true if y has only sequential labels from 1... \"\"\"\n"," labels = np.unique(y)\n"," return (set(labels)-{0}) == set(range(1,1+labels.max()))\n","\n","\n","def is_array_of_integers(y):\n"," return isinstance(y,np.ndarray) and np.issubdtype(y.dtype, np.integer)\n","\n","\n","def _check_label_array(y, name=None, check_sequential=False):\n"," err = ValueError(\"{label} must be an array of {integers}.\".format(\n"," label = 'labels' if name is None else name,\n"," integers = ('sequential ' if check_sequential else '') + 'non-negative integers',\n"," ))\n"," is_array_of_integers(y) or print(\"An error occured\")\n"," if check_sequential:\n"," label_are_sequential(y) or print(\"An error occured\")\n"," else:\n"," y.min() >= 0 or print(\"An error occured\")\n"," return True\n","\n","\n","def label_overlap(x, y, check=True):\n"," if check:\n"," _check_label_array(x,'x',True)\n"," _check_label_array(y,'y',True)\n"," x.shape == y.shape or _raise(ValueError(\"x and y must have the same shape\"))\n"," return _label_overlap(x, y)\n","\n","@jit(nopython=True)\n","def _label_overlap(x, y):\n"," x = x.ravel()\n"," y = y.ravel()\n"," overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)\n"," for i in range(len(x)):\n"," overlap[x[i],y[i]] += 1\n"," return overlap\n","\n","\n","def intersection_over_union(overlap):\n"," _check_label_array(overlap,'overlap')\n"," if np.sum(overlap) == 0:\n"," return overlap\n"," n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)\n"," n_pixels_true = np.sum(overlap, axis=1, keepdims=True)\n"," return overlap / (n_pixels_pred + n_pixels_true - overlap)\n","\n","matching_criteria['iou'] = intersection_over_union\n","\n","\n","def intersection_over_true(overlap):\n"," _check_label_array(overlap,'overlap')\n"," if np.sum(overlap) == 0:\n"," return overlap\n"," n_pixels_true = np.sum(overlap, axis=1, keepdims=True)\n"," return overlap / n_pixels_true\n","\n","matching_criteria['iot'] = 
intersection_over_true\n","\n","\n","def intersection_over_pred(overlap):\n"," _check_label_array(overlap,'overlap')\n"," if np.sum(overlap) == 0:\n"," return overlap\n"," n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)\n"," return overlap / n_pixels_pred\n","\n","matching_criteria['iop'] = intersection_over_pred\n","\n","\n","def precision(tp,fp,fn):\n"," return tp/(tp+fp) if tp > 0 else 0\n","def recall(tp,fp,fn):\n"," return tp/(tp+fn) if tp > 0 else 0\n","def accuracy(tp,fp,fn):\n"," # also known as \"average precision\" (?)\n"," # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation\n"," return tp/(tp+fp+fn) if tp > 0 else 0\n","def f1(tp,fp,fn):\n"," # also known as \"dice coefficient\"\n"," return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0\n","\n","\n","def _safe_divide(x,y):\n"," return x/y if y>0 else 0.0\n","\n","def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):\n"," \"\"\"Calculate detection/instance segmentation metrics between ground truth and predicted label images.\n"," Currently, the following metrics are implemented:\n"," 'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'\n"," Corresponding objects of y_true and y_pred are counted as true positives (tp), false positives (fp), and false negatives (fn)\n"," whether their intersection over union (IoU) >= thresh (for criterion='iou', which can be changed)\n"," * mean_matched_score is the mean IoUs of matched true positives\n"," * mean_true_score is the mean IoUs of matched true positives but normalized by the total number of GT objects\n"," * panoptic_quality defined as in Eq. 1 of Kirillov et al. \"Panoptic Segmentation\", CVPR 2019\n"," Parameters\n"," ----------\n"," y_true: ndarray\n"," ground truth label image (integer valued)\n"," predicted label image (integer valued)\n"," thresh: float\n"," threshold for matching criterion (default 0.5)\n"," criterion: string\n"," matching criterion (default IoU)\n"," report_matches: bool\n"," if True, additionally calculate matched_pairs and matched_scores (note, that this returns even gt-pred pairs whose scores are below 'thresh')\n"," Returns\n"," -------\n"," Matching object with different metrics as attributes\n"," Examples\n"," --------\n"," >>> y_true = np.zeros((100,100), np.uint16)\n"," >>> y_true[10:20,10:20] = 1\n"," >>> y_pred = np.roll(y_true,5,axis = 0)\n"," >>> stats = matching(y_true, y_pred)\n"," >>> print(stats)\n"," Matching(criterion='iou', thresh=0.5, fp=1, tp=0, fn=1, precision=0, recall=0, accuracy=0, f1=0, n_true=1, n_pred=1, mean_true_score=0.0, mean_matched_score=0.0, panoptic_quality=0.0)\n"," \"\"\"\n"," _check_label_array(y_true,'y_true')\n"," _check_label_array(y_pred,'y_pred')\n"," y_true.shape == y_pred.shape or _raise(ValueError(\"y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes\".format(y_true=y_true, y_pred=y_pred)))\n"," criterion in matching_criteria or _raise(ValueError(\"Matching criterion '%s' not supported.\" % criterion))\n"," if thresh is None: thresh = 0\n"," thresh = float(thresh) if np.isscalar(thresh) else map(float,thresh)\n","\n"," y_true, _, map_rev_true = relabel_sequential(y_true)\n"," y_pred, _, map_rev_pred = relabel_sequential(y_pred)\n","\n"," overlap = label_overlap(y_true, y_pred, check=False)\n"," scores = matching_criteria[criterion](overlap)\n"," assert 0 <= np.min(scores) <= np.max(scores) <= 1\n","\n"," # ignoring background\n"," scores = 
scores[1:,1:]\n"," n_true, n_pred = scores.shape\n"," n_matched = min(n_true, n_pred)\n","\n"," def _single(thr):\n"," not_trivial = n_matched > 0 and np.any(scores >= thr)\n"," if not_trivial:\n"," # compute optimal matching with scores as tie-breaker\n"," costs = -(scores >= thr).astype(float) - scores / (2*n_matched)\n"," true_ind, pred_ind = linear_sum_assignment(costs)\n"," assert n_matched == len(true_ind) == len(pred_ind)\n"," match_ok = scores[true_ind,pred_ind] >= thr\n"," tp = np.count_nonzero(match_ok)\n"," else:\n"," tp = 0\n"," fp = n_pred - tp\n"," fn = n_true - tp\n"," # assert tp+fp == n_pred\n"," # assert tp+fn == n_true\n","\n"," # the score sum over all matched objects (tp)\n"," sum_matched_score = np.sum(scores[true_ind,pred_ind][match_ok]) if not_trivial else 0.0\n","\n"," # the score average over all matched objects (tp)\n"," mean_matched_score = _safe_divide(sum_matched_score, tp)\n"," # the score average over all gt/true objects\n"," mean_true_score = _safe_divide(sum_matched_score, n_true)\n"," panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)\n","\n"," stats_dict = dict (\n"," criterion = criterion,\n"," thresh = thr,\n"," fp = fp,\n"," tp = tp,\n"," fn = fn,\n"," precision = precision(tp,fp,fn),\n"," recall = recall(tp,fp,fn),\n"," accuracy = accuracy(tp,fp,fn),\n"," f1 = f1(tp,fp,fn),\n"," n_true = n_true,\n"," n_pred = n_pred,\n"," mean_true_score = mean_true_score,\n"," mean_matched_score = mean_matched_score,\n"," panoptic_quality = panoptic_quality,\n"," )\n"," if bool(report_matches):\n"," if not_trivial:\n"," stats_dict.update (\n"," # int() to be json serializable\n"," matched_pairs = tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)),\n"," matched_scores = tuple(scores[true_ind,pred_ind]),\n"," matched_tps = tuple(map(int,np.flatnonzero(match_ok))),\n"," )\n"," else:\n"," stats_dict.update (\n"," matched_pairs = (),\n"," matched_scores = (),\n"," matched_tps = (),\n"," )\n"," return namedtuple('Matching',stats_dict.keys())(*stats_dict.values())\n","\n"," return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh))\n","\n","\n","\n","def matching_dataset(y_true, y_pred, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):\n"," \"\"\"matching metrics for list of images, see `stardist.matching.matching`\n"," \"\"\"\n"," len(y_true) == len(y_pred) or _raise(ValueError(\"y_true and y_pred must have the same length.\"))\n"," return matching_dataset_lazy (\n"," tuple(zip(y_true,y_pred)), thresh=thresh, criterion=criterion, by_image=by_image, show_progress=show_progress, parallel=parallel,\n"," )\n","\n","\n","\n","def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):\n","\n"," expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'))\n","\n"," single_thresh = False\n"," if np.isscalar(thresh):\n"," single_thresh = True\n"," thresh = (thresh,)\n","\n"," tqdm_kwargs = {}\n"," tqdm_kwargs['disable'] = not bool(show_progress)\n"," if int(show_progress) > 1:\n"," tqdm_kwargs['total'] = int(show_progress)\n","\n"," # compute matching stats for every pair of label images\n"," if parallel:\n"," from concurrent.futures import ThreadPoolExecutor\n"," fn = lambda pair: matching(*pair, thresh=thresh, criterion=criterion, report_matches=False)\n"," with ThreadPoolExecutor() as pool:\n"," 
stats_all = tuple(pool.map(fn, tqdm(y_gen,**tqdm_kwargs)))\n"," else:\n"," stats_all = tuple (\n"," matching(y_t, y_p, thresh=thresh, criterion=criterion, report_matches=False)\n"," for y_t,y_p in tqdm(y_gen,**tqdm_kwargs)\n"," )\n","\n"," # accumulate results over all images for each threshold separately\n"," n_images, n_threshs = len(stats_all), len(thresh)\n"," accumulate = [{} for _ in range(n_threshs)]\n"," for stats in stats_all:\n"," for i,s in enumerate(stats):\n"," acc = accumulate[i]\n"," for k,v in s._asdict().items():\n"," if k == 'mean_true_score' and not bool(by_image):\n"," # convert mean_true_score to \"sum_matched_score\"\n"," acc[k] = acc.setdefault(k,0) + v * s.n_true\n"," else:\n"," try:\n"," acc[k] = acc.setdefault(k,0) + v\n"," except TypeError:\n"," pass\n","\n"," # normalize/compute 'precision', 'recall', 'accuracy', 'f1'\n"," for thr,acc in zip(thresh,accumulate):\n"," set(acc.keys()) == expected_keys or _raise(ValueError(\"unexpected keys\"))\n"," acc['criterion'] = criterion\n"," acc['thresh'] = thr\n"," acc['by_image'] = bool(by_image)\n"," if bool(by_image):\n"," for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):\n"," acc[k] /= n_images\n"," else:\n"," tp, fp, fn, n_true = acc['tp'], acc['fp'], acc['fn'], acc['n_true']\n"," sum_matched_score = acc['mean_true_score']\n","\n"," mean_matched_score = _safe_divide(sum_matched_score, tp)\n"," mean_true_score = _safe_divide(sum_matched_score, n_true)\n"," panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)\n","\n"," acc.update(\n"," precision = precision(tp,fp,fn),\n"," recall = recall(tp,fp,fn),\n"," accuracy = accuracy(tp,fp,fn),\n"," f1 = f1(tp,fp,fn),\n"," mean_true_score = mean_true_score,\n"," mean_matched_score = mean_matched_score,\n"," panoptic_quality = panoptic_quality,\n"," )\n","\n"," accumulate = tuple(namedtuple('DatasetMatching',acc.keys())(*acc.values()) for acc in accumulate)\n"," return accumulate[0] if single_thresh else accumulate\n","\n","\n","\n","# copied from scikit-image master for now (remove when part of a release)\n","def relabel_sequential(label_field, offset=1):\n"," \"\"\"Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}.\n"," This function also returns the forward map (mapping the original labels to\n"," the reduced labels) and the inverse map (mapping the reduced labels back\n"," to the original ones).\n"," Parameters\n"," ----------\n"," label_field : numpy array of int, arbitrary shape\n"," An array of labels, which must be non-negative integers.\n"," offset : int, optional\n"," The return labels will start at `offset`, which should be\n"," strictly positive.\n"," Returns\n"," -------\n"," relabeled : numpy array of int, same shape as `label_field`\n"," The input label field with labels mapped to\n"," {offset, ..., number_of_labels + offset - 1}.\n"," The data type will be the same as `label_field`, except when\n"," offset + number_of_labels causes overflow of the current data type.\n"," forward_map : numpy array of int, shape ``(label_field.max() + 1,)``\n"," The map from the original label space to the returned label\n"," space. Can be used to re-apply the same mapping. See examples\n"," for usage. The data type will be the same as `relabeled`.\n"," inverse_map : 1D numpy array of int, of length offset + number of labels\n"," The map from the new label space to the original space. This\n"," can be used to reconstruct the original label field from the\n"," relabeled one. 
The data type will be the same as `relabeled`.\n"," Notes\n"," -----\n"," The label 0 is assumed to denote the background and is never remapped.\n"," The forward map can be extremely big for some inputs, since its\n"," length is given by the maximum of the label field. However, in most\n"," situations, ``label_field.max()`` is much smaller than\n"," ``label_field.size``, and in these cases the forward map is\n"," guaranteed to be smaller than either the input or output images.\n"," Examples\n"," --------\n"," >>> from skimage.segmentation import relabel_sequential\n"," >>> label_field = np.array([1, 1, 5, 5, 8, 99, 42])\n"," >>> relab, fw, inv = relabel_sequential(label_field)\n"," >>> relab\n"," array([1, 1, 2, 2, 3, 5, 4])\n"," >>> fw\n"," array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5])\n"," >>> inv\n"," array([ 0, 1, 5, 8, 42, 99])\n"," >>> (fw[label_field] == relab).all()\n"," True\n"," >>> (inv[relab] == label_field).all()\n"," True\n"," >>> relab, fw, inv = relabel_sequential(label_field, offset=5)\n"," >>> relab\n"," array([5, 5, 6, 6, 7, 9, 8])\n"," \"\"\"\n"," offset = int(offset)\n"," if offset <= 0:\n"," raise ValueError(\"Offset must be strictly positive.\")\n"," if np.min(label_field) < 0:\n"," raise ValueError(\"Cannot relabel array that contains negative values.\")\n"," max_label = int(label_field.max()) # Ensure max_label is an integer\n"," if not np.issubdtype(label_field.dtype, np.integer):\n"," new_type = np.min_scalar_type(max_label)\n"," label_field = label_field.astype(new_type)\n"," labels = np.unique(label_field)\n"," labels0 = labels[labels != 0]\n"," new_max_label = offset - 1 + len(labels0)\n"," new_labels0 = np.arange(offset, new_max_label + 1)\n"," output_type = label_field.dtype\n"," required_type = np.min_scalar_type(new_max_label)\n"," if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize:\n"," output_type = required_type\n"," forward_map = np.zeros(max_label + 1, dtype=output_type)\n"," forward_map[labels0] = new_labels0\n"," inverse_map = np.zeros(new_max_label + 1, dtype=output_type)\n"," inverse_map[offset:] = labels0\n"," relabeled = forward_map[label_field]\n"," return relabeled, forward_map, inverse_map\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Tbv6DpxZjVN3"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. 
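As an illustration (the plotting cell in this section is marked as not implemented yet), here is a minimal sketch of how such curves could be drawn, assuming the per-epoch loss values have been collected into two Python lists; the names `training_loss` and `validation_loss` are placeholders and are not variables defined by this notebook:

```python
import matplotlib.pyplot as plt

# Placeholder per-epoch loss values - replace with the values logged during your own training run.
training_loss = [0.90, 0.52, 0.31, 0.22, 0.18, 0.16, 0.15]
validation_loss = [0.95, 0.61, 0.40, 0.33, 0.31, 0.30, 0.31]

epochs = range(1, len(training_loss) + 1)
plt.plot(epochs, training_loss, label='Training loss')
plt.plot(epochs, validation_loss, label='Validation loss')
plt.yscale('log')  # a log scale keeps small late-stage changes visible
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
```

A validation curve that rises while the training curve keeps falling is the overfitting signature discussed below.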
Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"jtSv-B0AjX8j"},"source":["#@markdown ###Not implemented yet"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"s2VXDuiOF7r4"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder! The result for one of the images will also be displayed.\n","\n","The **Intersection over Union** (IoU) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess how accurately your model predicts nuclei. \n","\n","Here, the IoU is both calculated over the whole image and on a per-object basis. The value displayed below is the IoU value calculated over the entire image. The IoU value calculated on a per-object basis is used to calculate the other metrics displayed.\n","\n","“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. \n","\n","When a segmented object has an IoU value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. The number of “**true positives**” is available in the table below. The number of “false positive” is then defined as “**false positive**” = “n_pred” - “true positive”. The number of “false negative” is defined as “false negative” = “n_true” - “true positive”.\n","\n","The mean_matched_score is the mean IoU of matched true positives. The mean_true_score is the mean IoU of matched true positives, normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 2019](https://arxiv.org/abs/1801.00868).\n","\n","For more information about the other metrics displayed, please consult the SI of the paper describing ZeroCostDL4Mic.\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n","\n","**`model_choice`:** Choose the model to use to make predictions. This model needs to be a Cellpose model. 
You can also use the pretrained models already available in cellpose: \n","\n","- The cytoplasm model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is an optional nuclear channel.\n","- The cytoplasm2 model is an updated cytoplasm model trained with user-submitted images.\n","\n","- The nuclear model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is always set to an array of zeros.\n","\n","**`Channel_to_segment`:** Choose the channel to segment. If using single-channel grayscale images, choose \"Grayscale\".\n","\n","**`Nuclear_channel`:** If you are using a model that segment the \"cytoplasm\", you can use a nuclear channel to aid the segmentation. \n","\n","**`Object_diameter`:** Indicate the diameter of the objects (cells or Nuclei) you want to segment (in pixel). If you input \"0\", this parameter will be estimated automatically for each of your images.\n","\n","**`Flow_threshold`:** This parameter controls the maximum allowed error of the flows for each mask. Increase this threshold if cellpose is not returning as many masks as you'd expect. Similarly, decrease this threshold if cellpose is returning too many ill-shaped masks. **Default value: 0.4**\n","\n","**`Cell_probability_threshold`:** The pixels greater than the Cell_probability_threshold are used to run dynamics and determine masks. Decrease this threshold if cellpose is not returning as many masks as you'd expect. Similarly, increase this threshold if cellpose is returning too many masks, particularly from dim areas. **Default value: 0.0**"]},{"cell_type":"code","metadata":{"id":"BUrTuonhEH5J","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Channel_to_segment= \"Grayscale\" #@param [\"Grayscale\", \"Blue\", \"Green\", \"Red\"]\n","\n","# @markdown ###If you chose the model \"cytoplasm\" indicate if you also have a nuclear channel that can be used to aid the segmentation.\n","\n","Nuclear_channel= \"None\" #@param [\"None\", \"Blue\", \"Green\", \"Red\"]\n","\n","#@markdown ### Segmentation parameters:\n","Object_diameter = 0#@param {type:\"number\"}\n","\n","Flow_threshold = 0.4 #@param {type:\"slider\", min:0.1, max:1.1, step:0.1}\n","Cell_probability_threshold=0 #@param {type:\"slider\", min:-6, max:6, step:1}\n","\n","if Object_diameter is 0:\n"," Object_diameter = None\n"," print(\"The cell size will be estimated automatically for each image\")\n","\n","\n","# Find the number of channel in the input image\n","\n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = io.imread(Source_QC_folder+\"/\"+random_choice)\n","n_channel = 1 if x.ndim == 2 else x.shape[-1]\n","\n","if Channel_to_segment == \"Grayscale\":\n"," segment_channel = 0\n","\n"," if not n_channel == 1:\n"," print(bcolors.WARNING +\"!! 
WARNING: your image has more than one channel, choose which channel you want to use for QC !!\")\n","\n","if Channel_to_segment == \"Blue\":\n"," segment_channel = 3\n","\n","if Channel_to_segment == \"Green\":\n"," segment_channel = 2\n","\n","if Channel_to_segment == \"Red\":\n"," segment_channel = 1\n","\n","if Nuclear_channel == \"Blue\":\n"," nuclear_channel = 3\n","\n","if Nuclear_channel == \"Green\":\n"," nuclear_channel = 2\n","\n","if Nuclear_channel == \"Red\":\n"," nuclear_channel = 1\n","\n","if Nuclear_channel == \"None\":\n"," nuclear_channel = 0\n","\n","if QC_model == \"Cytoplasm\": \n"," channels=[segment_channel,nuclear_channel]\n","\n","if QC_model == \"Cytoplasm2\": \n"," channels=[segment_channel,nuclear_channel]\n"," \n","if QC_model == \"Nuclei\":\n"," channels=[segment_channel,0]\n","\n","if QC_model == \"Own_model\":\n"," channels=[segment_channel,nuclear_channel]\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_folder+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_folder+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_folder+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_folder+\"/Quality Control/Prediction\")\n","os.makedirs(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","\n","# Here we need to make predictions\n","\n","for name in os.listdir(Source_QC_folder):\n"," \n"," print(\"Performing prediction on: \"+name)\n"," image = io.imread(Source_QC_folder+\"/\"+name) \n","\n"," short_name = os.path.splitext(name)\n"," \n"," if QC_model == \"Own_model\":\n"," masks, flows, styles = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," else:\n"," masks, flows, styles, diams = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," \n"," os.chdir(QC_model_folder+\"/Quality Control/Prediction\")\n"," imsave(str(short_name[0])+\".tif\", masks, compress=ZIP_DEFLATED) \n"," \n","# Here we start testing the differences between GT and predicted masks\n","\n","with open(QC_model_folder+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file, delimiter=\",\")\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\", \"false positive\", \"true positive\", \"false negative\", \"precision\", \"recall\", \"accuracy\", \"f1 score\", \"n_true\", \"n_pred\", \"mean_true_score\", \"mean_matched_score\", \"panoptic_quality\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_folder+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," # Calculate the matching (with IoU threshold `thresh`) and all metrics\n","\n"," stats = matching(test_ground_truth_image, test_prediction, thresh=0.5)\n"," \n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])\n","\n","from tabulate import tabulate\n","\n","df = pd.read_csv (QC_model_folder+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\")\n","print(tabulate(df, headers='keys', tablefmt='psql'))\n","\n","\n","from astropy.visualization import simple_norm\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Source_QC_folder)):\n"," \n","\n"," plt.figure(figsize=(25,5))\n"," if n_channel > 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file))\n"," if n_channel == 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n","\n"," target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_folder+\"/Quality Control/Prediction/\"+file, as_gray = True)\n","\n"," stats = matching(prediction, target_image, thresh=0.5)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," if n_channel > 1:\n"," plt.imshow(source_image)\n"," if n_channel == 1:\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," 
plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));\n"," plt.savefig(QC_model_folder+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","full_QC_model_path = QC_model_folder+'/'\n","qc_pdf_export()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Io62PUMLagFS"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"E29LWfWpjkZU"},"source":["\n","\n","## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the model's name and path to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","**`model_choice`:** Choose the model to use to make predictions. This model needs to be a Cellpose model. You can also use the pretrained models already available in cellpose: \n","\n","- The cytoplasm model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is an optional nuclear channel.\n","- The cytoplasm2 model is an updated cytoplasm model trained with user-submitted images.\n","\n","- The nuclear model in cellpose is trained on two-channel images, where the first channel is the channel to segment, and the second channel is always set to an array of zeros.\n","\n","**`Channel_to_segment`:** Choose the channel to segment. If using single-channel grayscale images, choose \"Grayscale\".\n","\n","**`Nuclear_channel`:** If you are using a model that segment the \"cytoplasm\", you can use a nuclear channel to aid the segmentation. \n","\n","**`Object_diameter`:** Indicate the diameter of the objects (cells or Nuclei) you want to segment (in pixel). If you input \"0\", this parameter will be estimated automatically for each of your images.\n","\n","**`Flow_threshold`:** This parameter controls the maximum allowed error of the flows for each mask. Increase this threshold if cellpose is not returning as many masks as you'd expect. Similarly, decrease this threshold if cellpose is returning too many ill-shaped masks. **Default value: 0.4**\n","\n","**`Cell_probability_threshold`:** The pixels greater than the Cell_probability_threshold are used to run dynamics and determine masks. 
Decrease this threshold if cellpose is not returning as many masks as you'd expect. Similarly, increase this threshold if cellpose is returning too many masks, particularly from dim areas. **Default value: 0.0**\n","\n","**IMPORTANT:** One example result will be displayed first so that you can assess the quality of the prediction and change your settings accordingly. Once the most suitable settings have been chosen, press on the yellow button \"process your images\".\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"mfgvhMk2xid9","cellView":"form"},"source":["\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = \"Single_Images\" #@param [\"Single_Images\", \"Stacks\"]\n","\n","#@markdown ###What model do you want to use?\n","\n","model_choice = \"Cytoplasm2\" #@param [\"Cytoplasm\",\"Cytoplasm2\", \"Nuclei\", \"Own_model\"]\n","\n","#@markdown ####If using your own model, please provide the path to the model (not the folder):\n","\n","Prediction_model = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### What channel do you want to segment?\n","\n","Channel_to_segment= \"Grayscale\" #@param [\"Grayscale\", \"Blue\", \"Green\", \"Red\"]\n","\n","# @markdown ###If you chose the model \"cytoplasm\" indicate if you also have a nuclear channel that can be used to aid the segmentation.\n","\n","Nuclear_channel= \"None\" #@param [\"None\", \"Blue\", \"Green\", \"Red\"]\n","\n","#@markdown ### Segmentation parameters:\n","Object_diameter = 0#@param {type:\"number\"}\n","\n","Flow_threshold = 0.4 #@param {type:\"slider\", min:0.1, max:1.1, step:0.1}\n","Cell_probability_threshold=0 #@param {type:\"slider\", min:-6, max:6, step:1}\n","\n","# Find the number of channel in the input image\n","\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = io.imread(Data_folder+\"/\"+random_choice)\n","n_channel = 1 if x.ndim == 2 else x.shape[-1]\n","\n","if Channel_to_segment == \"Grayscale\":\n"," segment_channel = 0\n","\n"," if Data_type == \"Single_Images\":\n"," if not n_channel == 1:\n"," print(bcolors.WARNING +\"!! 
WARNING: your image has more than one channel, choose which channel you want to use for your predictions !!\")\n","\n","if Channel_to_segment == \"Blue\":\n"," segment_channel = 3\n","\n","if Channel_to_segment == \"Green\":\n"," segment_channel = 2\n","\n","if Channel_to_segment == \"Red\":\n"," segment_channel = 1\n","\n","if Nuclear_channel == \"Blue\":\n"," nuclear_channel = 3\n","\n","if Nuclear_channel == \"Green\":\n"," nuclear_channel = 2\n","\n","if Nuclear_channel == \"Red\":\n"," nuclear_channel = 1\n","\n","if Nuclear_channel == \"None\":\n"," nuclear_channel = 0\n","\n","if model_choice == \"Cytoplasm\": \n"," channels=[segment_channel,nuclear_channel]\n"," model = models.Cellpose(gpu=True, model_type=\"cyto\")\n"," print(\"Cytoplasm model enabled\")\n","\n","if model_choice == \"Cytoplasm2\": \n"," channels=[segment_channel,nuclear_channel]\n"," model = models.Cellpose(gpu=True, model_type=\"cyto2\")\n"," print(\"Cytoplasm model enabled\")\n"," \n","if model_choice == \"Nuclei\":\n"," channels=[segment_channel,0]\n"," model = models.Cellpose(gpu=True, model_type=\"nuclei\")\n"," print(\"Nuclei model enabled\")\n","\n","if model_choice == \"Own_model\":\n"," channels=[segment_channel,nuclear_channel]\n"," model = models.CellposeModel(gpu=True, pretrained_model=Prediction_model, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n","\n"," print(\"Own model enabled\")\n","\n","if Object_diameter is 0:\n"," Object_diameter = None\n"," print(\"The cell size will be estimated automatically for each image\")\n","\n","if Data_type == \"Single_Images\" :\n","\n"," print('--------------------------------------------------------------')\n"," @interact\n"," def preview_results(file = os.listdir(Data_folder)):\n"," source_image = io.imread(os.path.join(Data_folder, file))\n"," \n"," if model_choice == \"Own_model\":\n"," masks, flows, styles = model.eval(source_image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," else:\n"," masks, flows, styles, diams = model.eval(source_image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," \n"," flowi = flows[0]\n"," fig = plt.figure(figsize=(20,10))\n"," plot.show_segmentation(fig, source_image, masks, flowi, channels=channels)\n"," plt.tight_layout()\n"," plt.show()\n","\n","\n"," def batch_process():\n"," print(\"Your images are now beeing processed\")\n","\n"," for name in os.listdir(Data_folder):\n"," print(\"Performing prediction on: \"+name)\n"," image = io.imread(Data_folder+\"/\"+name)\n"," short_name = os.path.splitext(name)\n"," \n"," if model_choice == \"Own_model\":\n"," masks, flows, styles = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," else:\n"," masks, flows, styles, diams = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," \n"," os.chdir(Result_folder)\n"," imsave(str(short_name[0])+\"_mask.tif\", masks, compress=ZIP_DEFLATED)\n","\n"," im = interact_manual(batch_process)\n"," im.widget.children[0].description = 'Process your images'\n"," im.widget.children[0].style.button_color = 'yellow'\n"," display(im)\n","\n","if Data_type == \"Stacks\" :\n"," print(\"Stacks are now beeing predicted\")\n"," \n"," 
print('--------------------------------------------------------------')\n"," @interact\n"," def preview_results_stacks(file = os.listdir(Data_folder)):\n"," timelapse = imread(Data_folder+\"/\"+file)\n","\n"," if model_choice == \"Own_model\":\n"," masks, flows, styles = model.eval(timelapse[0], diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," else:\n"," masks, flows, styles, diams = model.eval(timelapse[0], diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," \n"," flowi = flows[0]\n"," fig = plt.figure(figsize=(20,10))\n"," plot.show_segmentation(fig, timelapse[0], masks, flowi, channels=channels)\n"," plt.tight_layout()\n"," plt.show()\n","\n"," def batch_process_stack():\n"," print(\"Your images are now beeing processed\") \n"," for image in os.listdir(Data_folder):\n"," print(\"Performing prediction on: \"+image)\n"," timelapse = imread(Data_folder+\"/\"+image)\n"," short_name = os.path.splitext(image)\n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," \n"," for t in range(n_timepoint):\n"," print(\"Frame number: \"+str(t))\n"," img_t = timelapse[t]\n","\n"," if model_choice == \"Own_model\":\n"," masks, flows, styles = model.eval(img_t, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," else:\n"," masks, flows, styles, diams = model.eval(img_t, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," \n"," \n"," prediction_stack[t] = masks\n"," \n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," os.chdir(Result_folder)\n"," imsave(str(short_name[0])+\".tif\", prediction_stack_32, compress=ZIP_DEFLATED)\n"," \n"," im = interact_manual(batch_process_stack)\n"," im.widget.children[0].description = 'Process your images'\n"," im.widget.children[0].style.button_color = 'yellow'\n"," display(im) \n"," \n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"dHUKCoSZ7dzV"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","\n","* Training now uses TORCH. 
\n","* This version also now includes built-in version check and the version log that you're reading now."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using Cellpose 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/DRMIME_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/DRMIME_2D_ZeroCostDL4Mic.ipynb index 3ed99d39..b6f09ee2 100644 --- a/Colab_notebooks/Beta notebooks/DRMIME_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Beta notebooks/DRMIME_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"DRMIME_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **DRMIME (2D)**\n","\n","---\n","\n"," DRMIME is a self-supervised deep-learning method that can be used to register 2D images.\n","\n"," **This particular notebook enables self-supervised registration of 2D dataset.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories. \n","\n","\n","While this notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (ZeroCostDL4Mic), this notebook structure substantially deviates from other ZeroCostDL4Mic notebooks and our template. This is because the deep learning method employed here is used to improve the image registration process. No Deep Learning models are actually saved, only the registered images. \n","\n","\n","This notebook is largely based on the following paper:\n","\n","DRMIME: Differentiable Mutual Information and Matrix Exponential for Multi-Resolution Image Registration by Abhishek Nan\n"," *et al.* published on arXiv in 2020 (https://arxiv.org/abs/2001.09865)\n","\n","And source code found in: /~https://github.com/abnan/DRMIME\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. 
You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For DRMIME to train, it requires at least two images. One **`\"Fixed image\"`** (template for the registration) and one **`Moving Image`** (image to be registered). Multiple **`Moving Images`** can also be provided if you want to register them to the same **`\"Fixed image\"`**. If you provide several **`Moving Images`**, multiple DRMIME instances will run one after another. \n","\n","The registration can also be applied to other channels. If you wish to apply the registration to other channels, please provide the images in another folder and carefully check your file names. Additional channels need to have the same name as the registered images and a prefix indicating the channel number starting at \"C1_\". See the example below. \n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," \n"," - **Fixed_image_folder**\n"," - img_1.tif (image used as template for the registration)\n"," - **Moving_image_folder**\n"," - img_3.tif, img_4.tif, ... (images to be registered) \n"," - **Folder_containing_additional_channels** (optional, if you want to apply the registration to other channel(s))\n"," - C1_img_3.tif, C1_img_4.tif, ...\n"," - C2_img_3.tif, C2_img_4.tif, ...\n"," - C3_img_3.tif, C3_img_4.tif, ...\n"," - **Results**\n","\n","The **Results** folder will contain the processed images and PDF reports. 
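If you use the additional-channels option, a quick way to verify the naming convention described above is to check that every moving image has a matching prefixed companion file. The snippet below is only an illustrative sketch; the two folder paths are placeholders that you would need to replace with your own:

```python
import os

# Placeholder paths - point these to your own folders.
Moving_image_folder = '/content/gdrive/My Drive/Data/Moving_image_folder'
Additional_channels_folder = '/content/gdrive/My Drive/Data/Folder_containing_additional_channels'

for name in os.listdir(Moving_image_folder):
    # Each additional channel must keep the same file name, prefixed with 'C1_', 'C2_', ...
    companion = os.path.join(Additional_channels_folder, 'C1_' + name)
    print(name, '->', 'C1_' + name, 'found' if os.path.exists(companion) else 'MISSING')
```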
Your original images remain unmodified.\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"cbTknRcviyT7"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"h5i5CS2bSmZr"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","#%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n3B3meGTbYVi"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. 
Install DRMIME and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","\n","#@markdown ##Install DRMIME and dependencies\n","\n","\n","# Here we install DRMIME and other required packages\n","\n","!pip install wget\n","\n","from skimage import io\n","import numpy as np\n","import math\n","import matplotlib.pyplot as plt\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","from torch.autograd import Variable\n","import torch.optim as optim\n","from skimage.transform import pyramid_gaussian\n","from skimage.filters import gaussian\n","from skimage.filters import threshold_otsu\n","from skimage.filters import sobel\n","from skimage.color import rgb2gray\n","from skimage import feature\n","from torch.autograd import Function\n","import cv2\n","from IPython.display import clear_output\n","import pandas as pd\n","from skimage.io import imsave\n","\n","\n","\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n","\n","#Create a pdf document with training summary, not yet implemented\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," # save FPDF() class into a \n"," # variable pdf \n"," #from datetime import datetime\n","\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'CARE 2D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). 
The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n"," if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if flip_left_right != 0 or flip_top_bottom != 0:\n"," aug_text = aug_text+'\\n- flipping'\n"," if random_zoom_magnification != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if random_distortion != 0:\n"," aug_text = aug_text+'\\n- random distortion'\n"," if image_shear != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," if skew_image != 0:\n"," aug_text = aug_text+'\\n- image skewing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40% style=\"margin-left:0px;\">\n","    <tr>\n","    <th align=\"left\">Parameter</th>\n","    <th align=\"left\">Value</th>\n","    </tr>
\n","    <tr>\n","    <td>number_of_epochs</td>\n","    <td>{0}</td>\n","    </tr>
\n","    <tr>\n","    <td>patch_size</td>\n","    <td>{1}</td>\n","    </tr>
\n","    <tr>\n","    <td>number_of_patches</td>\n","    <td>{2}</td>\n","    </tr>
\n","    <tr>\n","    <td>batch_size</td>\n","    <td>{3}</td>\n","    </tr>
\n","    <tr>\n","    <td>number_of_steps</td>\n","    <td>{4}</td>\n","    </tr>
\n","    <tr>\n","    <td>percentage_validation</td>\n","    <td>{5}</td>\n","    </tr>
\n","    <tr>\n","    <td>initial_learning_rate</td>\n","    <td>{6}</td>\n","    </tr>\n","    </table>
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_CARE2D.png').shape\n"," pdf.image('/content/TrainingDataExample_CARE2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," if augmentation:\n"," ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","\n","\n","print(\"Libraries installed\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","These is the path to your folders containing the image you want to register. 
To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`Fixed_image_folder`:** This is the folder containing your \"Fixed image\".\n","\n","**`Moving_image_folder`:** This is the folder containing your \"Moving Image(s)\".\n","\n","**`Result_folder`:** This is the folder where your results will be saved.\n","\n","\n","**Training Parameters**\n","\n","**`model_name`:** Choose a name for your model.\n","\n","**`number_of_iteration`:** Input how many iteration (rounds) the network will be trained. Preliminary results can already be observed after a 200 iterations, but a full training should run for 500-1000 iterations. **Default value: 500**\n","\n","**`Registration_mode`:** Choose which registration method you would like to use.\n","\n","**Additional channels**\n","\n"," This option enable you to apply the registration to other images (for instance other channels). Place these images in the **`Additional_channels_folder`**. Additional channels need to have the same name as the images you want to register (found in **`Moving_image_folder`**) and a prefix indicating the channel number starting at \"C1_\".\n","\n"," \n","**Advanced Parameters - experienced users only**\n","\n","**`n_neurons`:** Number of neurons (elementary constituents) that will assemble your model. **Default value: 100**.\n","\n","**`mine_initial_learning_rate`:** Input the initial value to be used as learning rate for MINE. **Default value: 0.001**\n","**`homography_net_vL_initial_learning_rate`:** Input the initial value to be used as learning rate for homography_net_vL. **Default value: 0.001**\n","\n","**`homography_net_v1_initial_learning_rate`:** Input the initial value to be used as learning rate for homography_net_v1. 
**Default value: 0.0001**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","#@markdown ###Path to the Fixed and Moving image folders: \n","Fixed_image_folder = \"\" #@param {type:\"string\"}\n","\n","\n","import os.path\n","from os import path\n","\n","if path.isfile(Fixed_image_folder):\n"," I = imread(Fixed_image_folder).astype(np.float32) # fixed image\n","\n","if path.isdir(Fixed_image_folder):\n"," Fixed_image = os.listdir(Fixed_image_folder)\n"," I = imread(Fixed_image_folder+\"/\"+Fixed_image[0]).astype(np.float32) # fixed image\n","\n","\n","Moving_image_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Provide the path to the folder where the predictions are to be saved\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Training Parameters\n","model_name = \"\" #@param {type:\"string\"}\n","\n","number_of_iteration = 500#@param {type:\"number\"}\n","\n","Registration_mode = \"Affine\" #@param [\"Affine\", \"Perspective\"]\n","\n","\n","#@markdown ###Do you want to apply the registration to other channel(s)?\n","Apply_registration_to_other_channels = False#@param {type:\"boolean\"}\n","\n","Additional_channels_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","n_neurons = 100 #@param {type:\"number\"}\n","mine_initial_learning_rate = 0.001 #@param {type:\"number\"}\n","homography_net_vL_initial_learning_rate = 0.001 #@param {type:\"number\"}\n","homography_net_v1_initial_learning_rate = 0.0001 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\") \n"," n_neurons = 100\n"," mine_initial_learning_rate = 0.001\n"," homography_net_vL_initial_learning_rate = 0.001\n"," homography_net_v1_initial_learning_rate = 0.0001\n","\n","\n","#failsafe for downscale could be useful \n","#to be added\n","\n","\n","#Load a random moving image to visualise and test the settings\n","random_choice = random.choice(os.listdir(Moving_image_folder))\n","J = imread(Moving_image_folder+\"/\"+random_choice).astype(np.float32)\n","\n","# Check if additional channel(s) need to be registered and if so how many\n","\n","print(str(len(os.listdir(Moving_image_folder)))+\" image(s) will be registered.\")\n","\n","if Apply_registration_to_other_channels:\n","\n"," other_channel_images = os.listdir(Additional_channels_folder)\n"," Number_of_other_channels = len(other_channel_images)/len(os.listdir(Moving_image_folder))\n","\n"," if Number_of_other_channels.is_integer():\n"," print(\"The registration(s) will be propagated to \"+str(Number_of_other_channels)+\" other channel(s)\")\n"," else:\n"," print(bcolors.WARNING +\"!! WARNING: Incorrect number of images in Folder_containing_additional_channels\"+W)\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(Result_folder+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","print(\"Example of two images to be registered\")\n","\n","#Here we display one image\n","f=plt.figure(figsize=(10,10))\n","plt.subplot(1,2,1)\n","plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest')\n","\n","\n","plt.title('Fixed image')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(J, norm=simple_norm(J, percent = 99), interpolation='nearest')\n","plt.title('Moving image')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_DRMIME2D.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QpKgUER3y9tn"},"source":["## **3.2. Choose and test the image pre-processing settings**\n","---\n"," DRMIME makes use of multi-resolution image pyramids to perform registration. Unlike a conventional method where computation starts at the highest level of the image pyramid and gradually proceeds to the lower levels, DRMIME simultaneously use all the levels in gradient descent-based optimization using automatic differentiation. Here, you can choose the parameters that define the multi-resolution image pyramids that will be used.\n","\n","**`nb_images_pyramid`:** Choose the number of images to use to assemble the pyramid. **Default value: 10**.\n","\n","**`Level_downscaling`:** Choose the level of downscaling that will be used to create the images of the pyramid **Default value: 1.8**.\n","\n","**`sampling`:** amount of sampling used for the perspective registration. **Default value: 0.1**.\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"MoNXLwG6yd76"},"source":["\n","#@markdown ##Image pre-processing settings\n","\n","nb_images_pyramid = 10#@param {type:\"number\"} # where registration starts (at the coarsest resolution)\n","\n","L = nb_images_pyramid\n","\n","Level_downscaling = 1.8#@param {type:\"number\"}\n","\n","downscale = Level_downscaling\n","\n","sampling = 0.1#@param {type:\"number\"} # 10% sampling used only for perspective registration\n","\n","\n","ifplot=True\n","if np.ndim(I) == 3:\n"," nChannel=I.shape[2]\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n","elif np.ndim(I) == 2:\n"," nChannel=1\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n","else:\n"," print(\"Unknown rank for an image\")\n","\n","\n","# Control the display\n","width=5\n","height=5\n","rows = int(L/5)+1\n","cols = 5\n","axes=[]\n","fig=plt.figure(figsize=(16,16))\n","\n","if Registration_mode == \"Affine\":\n","\n"," print(\"Affine registration selected\")\n","\n","# create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), 
None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n","\n"," if nChannel>1:\n"," I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_) \n"," \n"," axes.append( fig.add_subplot(rows, cols, s+1) )\n"," subplot_title=(str(s))\n"," axes[-1].set_title(subplot_title) \n"," plt.imshow(edges_grayscale)\n"," plt.axis('off');\n","\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," fig.tight_layout()\n","\n"," plt.show()\n","\n","\n","if Registration_mode == \"Perspective\":\n","\n"," print(\"Perspective registration selected\")\n","\n","# create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," \n"," if nChannel>1:\n"," I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 10),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_) \n"," \n"," axes.append( fig.add_subplot(rows, cols, s+1) )\n"," subplot_title=(str(s))\n"," axes[-1].set_title(subplot_title) \n"," plt.imshow(edges_grayscale)\n"," plt.axis('off');\n","\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," fig.tight_layout()\n","\n"," plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"keIQhCmOMv5S"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Ovu0ESxivcxx"},"source":["## **4.1. Prepare for training**\n","---\n","Here, we use the information from 3. 
to load the correct dependencies."]},{"cell_type":"code","metadata":{"id":"t4QTv4vQvbnS","cellView":"form"},"source":["#@markdown ##Load the dependencies required for training\n","\n","print(\"--------------------------------------------------\")\n","\n","# Remove the model name folder if exists\n","\n","if os.path.exists(Result_folder+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\"+W)\n"," shutil.rmtree(Result_folder+'/'+model_name)\n","os.makedirs(Result_folder+'/'+model_name)\n","\n","\n","\n","if Registration_mode == \"Affine\":\n","\n"," class HomographyNet(nn.Module):\n"," def __init__(self):\n"," super(HomographyNet, self).__init__()\n"," # affine transform basis matrices\n","\n"," self.B = torch.zeros(6,3,3).to(device)\n"," self.B[0,0,2] = 1.0\n"," self.B[1,1,2] = 1.0\n"," self.B[2,0,1] = 1.0\n"," self.B[3,1,0] = 1.0\n"," self.B[4,0,0], self.B[4,1,1] = 1.0, -1.0\n"," self.B[5,1,1], self.B[5,2,2] = -1.0, 1.0\n","\n"," self.v1 = torch.nn.Parameter(torch.zeros(6,1,1).to(device), requires_grad=True)\n"," self.vL = torch.nn.Parameter(torch.zeros(6,1,1).to(device), requires_grad=True)\n","\n"," def forward(self, s):\n"," C = torch.sum(self.B*self.vL,0)\n"," if s==0:\n"," C += torch.sum(self.B*self.v1,0)\n"," A = torch.eye(3).to(device)\n"," H = A\n"," for i in torch.arange(1,10):\n"," A = torch.mm(A/i,C)\n"," H = H + A\n"," return H\n","\n"," class MINE(nn.Module): #https://arxiv.org/abs/1801.04062\n"," def __init__(self):\n"," super(MINE, self).__init__()\n"," self.fc1 = nn.Linear(2*nChannel, n_neurons)\n"," self.fc2 = nn.Linear(n_neurons, n_neurons)\n"," self.fc3 = nn.Linear(n_neurons, 1)\n"," self.bsize = 1 # 1 may be sufficient\n","\n"," def forward(self, x, ind):\n"," x = x.view(x.size()[0]*x.size()[1],x.size()[2])\n"," MI_lb=0.0\n"," for i in range(self.bsize):\n"," ind_perm = ind[torch.randperm(len(ind))]\n"," z1 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(x[ind,:])))))\n"," z2 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(torch.cat((x[ind,0:nChannel],x[ind_perm,nChannel:2*nChannel]),1))))))\n"," MI_lb += torch.mean(z1) - torch.log(torch.mean(torch.exp(z2)))\n","\n"," return MI_lb/self.bsize\n","\n"," def AffineTransform(I, H, xv, yv):\n"," # apply affine transform\n"," xvt = (xv*H[0,0]+yv*H[0,1]+H[0,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," yvt = (xv*H[1,0]+yv*H[1,1]+H[1,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," J = F.grid_sample(I,torch.stack([xvt,yvt],2).unsqueeze(0)).squeeze()\n"," return J\n","\n","\n"," def multi_resolution_loss():\n"," loss=0.0\n"," for s in np.arange(L-1,-1,-1):\n"," if nChannel>1:\n"," Jw_ = AffineTransform(J_lst[s].unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.cat([I_lst[s],Jw_],0).permute(1,2,0),ind_lst[s])\n"," loss = loss - (1./L)*mi\n"," else:\n"," Jw_ = AffineTransform(J_lst[s].unsqueeze(0).unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.stack([I_lst[s],Jw_],2),ind_lst[s])\n"," loss = loss - (1./L)*mi\n","\n"," return loss\n","\n","\n","\n","if Registration_mode == \"Perspective\":\n","\n"," class HomographyNet(nn.Module):\n"," def __init__(self):\n"," super(HomographyNet, self).__init__()\n"," # affine transform basis matrices\n","\n"," self.B = torch.zeros(8,3,3).to(device)\n"," self.B[0,0,2] = 1.0\n"," self.B[1,1,2] = 1.0\n"," self.B[2,0,1] = 1.0\n"," self.B[3,1,0] = 1.0\n"," self.B[4,0,0], self.B[4,1,1] = 1.0, -1.0\n"," self.B[5,1,1], self.B[5,2,2] = -1.0, 1.0\n"," self.B[6,2,0] = 
1.0\n"," self.B[7,2,1] = 1.0\n","\n"," self.v1 = torch.nn.Parameter(torch.zeros(8,1,1).to(device), requires_grad=True)\n"," self.vL = torch.nn.Parameter(torch.zeros(8,1,1).to(device), requires_grad=True)\n","\n"," def forward(self, s):\n"," C = torch.sum(self.B*self.vL,0)\n"," if s==0:\n"," C += torch.sum(self.B*self.v1,0)\n"," A = torch.eye(3).to(device)\n"," H = A\n"," for i in torch.arange(1,10):\n"," A = torch.mm(A/i,C)\n"," H = H + A\n"," return H\n","\n","\n"," class MINE(nn.Module): #https://arxiv.org/abs/1801.04062\n"," def __init__(self):\n"," super(MINE, self).__init__()\n"," self.fc1 = nn.Linear(2*nChannel, n_neurons)\n"," self.fc2 = nn.Linear(n_neurons, n_neurons)\n"," self.fc3 = nn.Linear(n_neurons, 1)\n"," self.bsize = 1 # 1 may be sufficient\n","\n"," def forward(self, x, ind):\n"," x = x.view(x.size()[0]*x.size()[1],x.size()[2])\n"," MI_lb=0.0\n"," for i in range(self.bsize):\n"," ind_perm = ind[torch.randperm(len(ind))]\n"," z1 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(x[ind,:])))))\n"," z2 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(torch.cat((x[ind,0:nChannel],x[ind_perm,nChannel:2*nChannel]),1))))))\n"," MI_lb += torch.mean(z1) - torch.log(torch.mean(torch.exp(z2)))\n","\n"," return MI_lb/self.bsize\n","\n","\n"," def PerspectiveTransform(I, H, xv, yv):\n"," # apply homography\n"," xvt = (xv*H[0,0]+yv*H[0,1]+H[0,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," yvt = (xv*H[1,0]+yv*H[1,1]+H[1,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," J = F.grid_sample(I,torch.stack([xvt,yvt],2).unsqueeze(0)).squeeze()\n"," return J\n","\n","\n"," def multi_resolution_loss():\n"," loss=0.0\n"," for s in np.arange(L-1,-1,-1):\n"," if nChannel>1:\n"," Jw_ = PerspectiveTransform(J_lst[s].unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.cat([I_lst[s],Jw_],0).permute(1,2,0),ind_lst[s])\n"," loss = loss - (1./L)*mi\n"," else:\n"," Jw_ = PerspectiveTransform(J_lst[s].unsqueeze(0).unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.stack([I_lst[s],Jw_],2),ind_lst[s])\n"," loss = loss - (1./L)*mi\n","\n"," return loss\n","\n"," def histogram_mutual_information(image1, image2):\n"," hgram, x_edges, y_edges = np.histogram2d(image1.ravel(), image2.ravel(), bins=100)\n"," pxy = hgram / float(np.sum(hgram))\n"," px = np.sum(pxy, axis=1)\n"," py = np.sum(pxy, axis=0)\n"," px_py = px[:, None] * py[None, :]\n"," nzs = pxy > 0\n"," return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))\n","\n","\n","print(\"Done\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each iterations (round). A new network will be trained for each image that need to be registered.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n"]},{"cell_type":"code","metadata":{"id":"fisJmA13Mv5e","scrolled":true,"cellView":"form"},"source":["#@markdown ##Start training and the registration process\n","\n","start = time.time()\n","\n","loop_number = 1\n","\n","\n","\n","if Registration_mode == \"Affine\":\n","\n"," print(\"Affine registration.....\")\n","\n"," for image in os.listdir(Moving_image_folder):\n","\n"," if path.isfile(Fixed_image_folder):\n"," I = imread(Fixed_image_folder).astype(np.float32) # fixed image\n","\n"," if path.isdir(Fixed_image_folder):\n"," Fixed_image = os.listdir(Fixed_image_folder)\n"," I = imread(Fixed_image_folder+\"/\"+Fixed_image[0]).astype(np.float32) # fixed image\n","\n"," J = imread(Moving_image_folder+\"/\"+image).astype(np.float32)\n","\n"," # Here we generate the pyramidal images\n"," ifplot=True\n"," if np.ndim(I) == 3:\n"," nChannel=I.shape[2]\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," elif np.ndim(I) == 2:\n"," nChannel=1\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," else:\n"," print(\"Unknown rank for an image\")\n","\n","\n"," # create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n","\n","\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n","\n"," if nChannel>1:\n"," I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_)\n","\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," homography_net = HomographyNet().to(device)\n"," mine_net = MINE().to(device)\n","\n"," optimizer = optim.Adam([{'params': mine_net.parameters(), 'lr': 1e-3},\n"," {'params': homography_net.vL, 'lr': 5e-3},\n"," {'params': homography_net.v1, 'lr': 1e-4}], amsgrad=True)\n"," mi_list = []\n"," for itr in range(number_of_iteration):\n"," 
optimizer.zero_grad()\n"," loss = multi_resolution_loss()\n"," mi_list.append(-loss.item())\n"," loss.backward()\n"," optimizer.step()\n"," clear_output(wait=True)\n"," plt.plot(mi_list)\n"," plt.xlabel('Iteration number')\n"," plt.ylabel('MI')\n"," plt.title(image+\". Image registration \"+str(loop_number)+\" out of \"+str(len(os.listdir(Moving_image_folder)))+\".\")\n"," plt.show()\n","\n"," I_t = torch.tensor(I).to(device) # without Gaussian\n"," J_t = torch.tensor(J).to(device) # without Gaussian\n"," H = homography_net(0)\n"," if nChannel>1:\n"," J_w = AffineTransform(J_t.permute(2,0,1).unsqueeze(0), H, xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze().permute(1,2,0)\n"," else:\n"," J_w = AffineTransform(J_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n","\n"," #Apply registration to other channels\n","\n"," if Apply_registration_to_other_channels:\n","\n"," for n_channel in range(1, int(Number_of_other_channels)+1):\n","\n"," channel = imread(Additional_channels_folder+\"/C\"+str(n_channel)+\"_\"+image).astype(np.float32)\n"," channel_t = torch.tensor(channel).to(device)\n"," channel_w = AffineTransform(channel_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n"," channel_registered = channel_w.cpu().data.numpy()\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+\"C\"+str(n_channel)+\"_\"+image+\"_\"+Registration_mode+\"_registered.tif\", channel_registered)\n"," \n","# Export results to numpy array\n"," registered = J_w.cpu().data.numpy()\n","# Save results\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+image+\"_\"+Registration_mode+\"_registered.tif\", registered)\n","\n"," loop_number = loop_number + 1\n","\n"," print(\"Your images have been registered and saved in your result_folder\")\n","\n","\n","#Perspective registration\n","\n","if Registration_mode == \"Perspective\":\n","\n"," print(\"Perspective registration.....\")\n","\n"," for image in os.listdir(Moving_image_folder):\n","\n"," if path.isfile(Fixed_image_folder):\n"," I = imread(Fixed_image_folder).astype(np.float32) # fixed image\n","\n"," if path.isdir(Fixed_image_folder):\n"," Fixed_image = os.listdir(Fixed_image_folder)\n"," I = imread(Fixed_image).astype(np.float32) # fixed image\n","\n"," J = imread(Moving_image_folder+\"/\"+image).astype(np.float32)\n","\n"," # Here we generate the pyramidal images\n"," ifplot=True\n"," if np.ndim(I) == 3:\n"," nChannel=I.shape[2]\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," elif np.ndim(I) == 2:\n"," nChannel=1\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," else:\n"," print(\"Unknown rank for an image\")\n","\n","\n"," # create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n","\n"," if nChannel>1:\n"," 
I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 10),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_)\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," homography_net = HomographyNet().to(device)\n"," mine_net = MINE().to(device)\n","\n"," optimizer = optim.Adam([{'params': mine_net.parameters(), 'lr': 1e-3},\n"," {'params': homography_net.vL, 'lr': 1e-3},\n"," {'params': homography_net.v1, 'lr': 1e-4}], amsgrad=True)\n"," mi_list = []\n"," for itr in range(number_of_iteration):\n"," optimizer.zero_grad()\n"," loss = multi_resolution_loss()\n"," mi_list.append(-loss.item())\n"," loss.backward()\n"," optimizer.step()\n"," clear_output(wait=True)\n"," plt.plot(mi_list)\n"," plt.xlabel('Iteration number')\n"," plt.ylabel('MI')\n"," plt.title(image+\". Image registration \"+str(loop_number)+\" out of \"+str(len(os.listdir(Moving_image_folder)))+\".\")\n"," plt.show()\n","\n"," I_t = torch.tensor(I).to(device) # without Gaussian\n"," J_t = torch.tensor(J).to(device) # without Gaussian\n"," H = homography_net(0)\n"," if nChannel>1:\n"," J_w = PerspectiveTransform(J_t.permute(2,0,1).unsqueeze(0), H, xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze().permute(1,2,0)\n"," else:\n"," J_w = PerspectiveTransform(J_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n","\n"," #Apply registration to other channels\n","\n"," if Apply_registration_to_other_channels:\n","\n"," for n_channel in range(1, int(Number_of_other_channels)+1):\n","\n"," channel = imread(Additional_channels_folder+\"/C\"+str(n_channel)+\"_\"+image).astype(np.float32)\n"," channel_t = torch.tensor(channel).to(device)\n"," channel_w = PerspectiveTransform(channel_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n"," channel_registered = channel_w.cpu().data.numpy()\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+\"C\"+str(n_channel)+\"_\"+image+\"_Perspective_registered.tif\", channel_registered) \n","\n","\n","# Export results to numpy array\n"," registered = J_w.cpu().data.numpy()\n","# Save results\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+image+\"_Perspective_registered.tif\", registered)\n","\n"," loop_number = loop_number + 1\n","\n"," print(\"Your images have been registered and saved in your result_folder\")\n","\n","\n","# PDF export missing \n","\n","#pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PfTw_pQUUAqB"},"source":["## **4.3. 
Assess the registration**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"SrArBvqwYvc9","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Moving_image_folder)):\n","\n"," moving_image = imread(Moving_image_folder+\"/\"+file).astype(np.float32)\n"," \n"," registered_image = imread(Result_folder+\"/\"+model_name+\"/\"+file+\"_\"+Registration_mode+\"_registered.tif\").astype(np.float32)\n","\n","#Here we display one image\n","\n"," f=plt.figure(figsize=(20,20))\n"," plt.subplot(1,5,1)\n"," plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest')\n"," plt.title('Fixed image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,2)\n"," plt.imshow(moving_image, norm=simple_norm(moving_image, percent = 99), interpolation='nearest')\n"," plt.title('Moving image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,3)\n"," plt.imshow(registered_image, norm=simple_norm(registered_image, percent = 99), interpolation='nearest')\n"," plt.title(\"Registered image\")\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,4)\n"," plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest', cmap=\"Greens\")\n"," plt.imshow(moving_image, norm=simple_norm(moving_image, percent = 99), interpolation='nearest', cmap=\"Oranges\", alpha=0.5)\n"," plt.title(\"Fixed and moving images\")\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,5)\n"," plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest', cmap=\"Greens\")\n"," plt.imshow(registered_image, norm=simple_norm(registered_image, percent = 99), interpolation='nearest', cmap=\"Oranges\", alpha=0.5)\n"," plt.title(\"Fixed and Registered images\")\n"," plt.axis('off');\n","\n"," plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wgO7Ok1PBFQj"},"source":["## **4.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, etc.) if you plan to train or use new networks. 
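Besides the visual overlays shown in section 4.3, a histogram-based mutual information score (the same formula as the `histogram_mutual_information` helper defined in section 4.1 for the Perspective mode) can give a quick quantitative check: MI between the fixed and registered images should be higher than between the fixed and the original moving images. The sketch below is purely illustrative; the file names are hypothetical and it assumes the three images share the same shape.

```python
import numpy as np
from tifffile import imread

def histogram_mi(a, b, bins=100):
    # Classical histogram-based mutual information, matching
    # histogram_mutual_information() from section 4.1
    hgram, _, _ = np.histogram2d(a.ravel(), b.ravel(), bins=bins)
    pxy = hgram / hgram.sum()
    px, py = pxy.sum(axis=1), pxy.sum(axis=0)
    px_py = px[:, None] * py[None, :]
    nzs = pxy > 0
    return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))

# Hypothetical file names -- adapt to your own Moving_image_folder / Result_folder
fixed      = imread("fixed.tif").astype(np.float32)
moving     = imread("moving.tif").astype(np.float32)
registered = imread("moving.tif_Affine_registered.tif").astype(np.float32)

print("MI fixed vs moving    :", histogram_mi(fixed, moving))
print("MI fixed vs registered:", histogram_mi(fixed, registered))  # expected to be higher
```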
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS"},"source":["#**Thank you for using DRMIME 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"DRMIME_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **DRMIME (2D)**\n","\n","---\n","\n"," DRMIME is a self-supervised deep-learning method that can be used to register 2D images.\n","\n"," **This particular notebook enables self-supervised registration of 2D dataset.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories. \n","\n","\n","While this notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (ZeroCostDL4Mic), this notebook structure substantially deviates from other ZeroCostDL4Mic notebooks and our template. This is because the deep learning method employed here is used to improve the image registration process. No Deep Learning models are actually saved, only the registered images. \n","\n","\n","This notebook is largely based on the following paper:\n","\n","DRMIME: Differentiable Mutual Information and Matrix Exponential for Multi-Resolution Image Registration by Abhishek Nan\n"," *et al.* published on arXiv in 2020 (https://arxiv.org/abs/2001.09865)\n","\n","And source code found in: /~https://github.com/abnan/DRMIME\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. 
After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For DRMIME to train, it requires at least two images. One **`\"Fixed image\"`** (template for the registration) and one **`Moving Image`** (image to be registered). Multiple **`Moving Images`** can also be provided if you want to register them to the same **`\"Fixed image\"`**. If you provide several **`Moving Images`**, multiple DRMIME instances will run one after another. \n","\n","The registration can also be applied to other channels. If you wish to apply the registration to other channels, please provide the images in another folder and carefully check your file names. Additional channels need to have the same name as the registered images and a prefix indicating the channel number starting at \"C1_\". See the example below. \n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," \n"," - **Fixed_image_folder**\n"," - img_1.tif (image used as template for the registration)\n"," - **Moving_image_folder**\n"," - img_3.tif, img_4.tif, ... (images to be registered) \n"," - **Folder_containing_additional_channels** (optional, if you want to apply the registration to other channel(s))\n"," - C1_img_3.tif, C1_img_4.tif, ...\n"," - C2_img_3.tif, C2_img_4.tif, ...\n"," - C3_img_3.tif, C3_img_4.tif, ...\n"," - **Results**\n","\n","The **Results** folder will contain the processed images and PDF reports. Your original images remain unmodified.\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. 
Install DRMIME and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'DRMIME'\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory\n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install DRMIME and dependencies\n","\n","# Here we install DRMIME and other required packages\n","\n","!pip install wget\n","\n","from skimage import io\n","import numpy as np\n","import math\n","import matplotlib.pyplot as plt\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","from torch.autograd import Variable\n","import torch.optim as optim\n","from skimage.transform import pyramid_gaussian\n","from skimage.filters import gaussian\n","from skimage.filters import threshold_otsu\n","from skimage.filters import sobel\n","from skimage.color import rgb2gray\n","from skimage import feature\n","from torch.autograd import Function\n","import cv2\n","from IPython.display import clear_output\n","import pandas as pd\n","from skimage.io import imsave\n","\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import 
warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","!pip freeze > requirements.txt\n","\n","\n","print(\"Libraries installed\")\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"cbTknRcviyT7"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"h5i5CS2bSmZr"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","#%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n3B3meGTbYVi"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","These is the path to your folders containing the image you want to register. To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`Fixed_image_folder`:** This is the folder containing your \"Fixed image\".\n","\n","**`Moving_image_folder`:** This is the folder containing your \"Moving Image(s)\".\n","\n","**`Result_folder`:** This is the folder where your results will be saved.\n","\n","\n","**Training Parameters**\n","\n","**`model_name`:** Choose a name for your model.\n","\n","**`number_of_iteration`:** Input how many iteration (rounds) the network will be trained. Preliminary results can already be observed after a 200 iterations, but a full training should run for 500-1000 iterations. **Default value: 500**\n","\n","**`Registration_mode`:** Choose which registration method you would like to use.\n","\n","**Additional channels**\n","\n"," This option enable you to apply the registration to other images (for instance other channels). Place these images in the **`Additional_channels_folder`**. Additional channels need to have the same name as the images you want to register (found in **`Moving_image_folder`**) and a prefix indicating the channel number starting at \"C1_\".\n","\n"," \n","**Advanced Parameters - experienced users only**\n","\n","**`n_neurons`:** Number of neurons (elementary constituents) that will assemble your model. **Default value: 100**.\n","\n","**`mine_initial_learning_rate`:** Input the initial value to be used as learning rate for MINE. **Default value: 0.001**\n","**`homography_net_vL_initial_learning_rate`:** Input the initial value to be used as learning rate for homography_net_vL. **Default value: 0.001**\n","\n","**`homography_net_v1_initial_learning_rate`:** Input the initial value to be used as learning rate for homography_net_v1. 
**Default value: 0.0001**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","#@markdown ###Path to the Fixed and Moving image folders: \n","Fixed_image_folder = \"\" #@param {type:\"string\"}\n","\n","\n","import os.path\n","from os import path\n","\n","if path.isfile(Fixed_image_folder):\n"," I = imread(Fixed_image_folder).astype(np.float32) # fixed image\n","\n","if path.isdir(Fixed_image_folder):\n"," Fixed_image = os.listdir(Fixed_image_folder)\n"," I = imread(Fixed_image_folder+\"/\"+Fixed_image[0]).astype(np.float32) # fixed image\n","\n","\n","Moving_image_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Provide the path to the folder where the predictions are to be saved\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Training Parameters\n","model_name = \"\" #@param {type:\"string\"}\n","\n","number_of_iteration = 500#@param {type:\"number\"}\n","\n","Registration_mode = \"Affine\" #@param [\"Affine\", \"Perspective\"]\n","\n","\n","#@markdown ###Do you want to apply the registration to other channel(s)?\n","Apply_registration_to_other_channels = False#@param {type:\"boolean\"}\n","\n","Additional_channels_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","n_neurons = 100 #@param {type:\"number\"}\n","mine_initial_learning_rate = 0.001 #@param {type:\"number\"}\n","homography_net_vL_initial_learning_rate = 0.001 #@param {type:\"number\"}\n","homography_net_v1_initial_learning_rate = 0.0001 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\") \n"," n_neurons = 100\n"," mine_initial_learning_rate = 0.001\n"," homography_net_vL_initial_learning_rate = 0.001\n"," homography_net_v1_initial_learning_rate = 0.0001\n","\n","\n","#failsafe for downscale could be useful \n","#to be added\n","\n","\n","#Load a random moving image to visualise and test the settings\n","random_choice = random.choice(os.listdir(Moving_image_folder))\n","J = imread(Moving_image_folder+\"/\"+random_choice).astype(np.float32)\n","\n","# Check if additional channel(s) need to be registered and if so how many\n","\n","print(str(len(os.listdir(Moving_image_folder)))+\" image(s) will be registered.\")\n","\n","if Apply_registration_to_other_channels:\n","\n"," other_channel_images = os.listdir(Additional_channels_folder)\n"," Number_of_other_channels = len(other_channel_images)/len(os.listdir(Moving_image_folder))\n","\n"," if Number_of_other_channels.is_integer():\n"," print(\"The registration(s) will be propagated to \"+str(Number_of_other_channels)+\" other channel(s)\")\n"," else:\n"," print(bcolors.WARNING +\"!! WARNING: Incorrect number of images in Folder_containing_additional_channels\"+W)\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(Result_folder+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","print(\"Example of two images to be registered\")\n","\n","#Here we display one image\n","f=plt.figure(figsize=(10,10))\n","plt.subplot(1,2,1)\n","plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest')\n","\n","\n","plt.title('Fixed image')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(J, norm=simple_norm(J, percent = 99), interpolation='nearest')\n","plt.title('Moving image')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_DRMIME2D.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QpKgUER3y9tn"},"source":["## **3.2. Choose and test the image pre-processing settings**\n","---\n"," DRMIME makes use of multi-resolution image pyramids to perform registration. Unlike a conventional method where computation starts at the highest level of the image pyramid and gradually proceeds to the lower levels, DRMIME simultaneously use all the levels in gradient descent-based optimization using automatic differentiation. Here, you can choose the parameters that define the multi-resolution image pyramids that will be used.\n","\n","**`nb_images_pyramid`:** Choose the number of images to use to assemble the pyramid. **Default value: 10**.\n","\n","**`Level_downscaling`:** Choose the level of downscaling that will be used to create the images of the pyramid **Default value: 1.8**.\n","\n","**`sampling`:** amount of sampling used for the perspective registration. **Default value: 0.1**.\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"MoNXLwG6yd76"},"source":["\n","#@markdown ##Image pre-processing settings\n","\n","nb_images_pyramid = 10#@param {type:\"number\"} # where registration starts (at the coarsest resolution)\n","\n","L = nb_images_pyramid\n","\n","Level_downscaling = 1.8#@param {type:\"number\"}\n","\n","downscale = Level_downscaling\n","\n","sampling = 0.1#@param {type:\"number\"} # 10% sampling used only for perspective registration\n","\n","\n","ifplot=True\n","if np.ndim(I) == 3:\n"," nChannel=I.shape[2]\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n","elif np.ndim(I) == 2:\n"," nChannel=1\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n","else:\n"," print(\"Unknown rank for an image\")\n","\n","\n","# Control the display\n","width=5\n","height=5\n","rows = int(L/5)+1\n","cols = 5\n","axes=[]\n","fig=plt.figure(figsize=(16,16))\n","\n","if Registration_mode == \"Affine\":\n","\n"," print(\"Affine registration selected\")\n","\n","# create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), 
None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n","\n"," if nChannel>1:\n"," I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_) \n"," \n"," axes.append( fig.add_subplot(rows, cols, s+1) )\n"," subplot_title=(str(s))\n"," axes[-1].set_title(subplot_title) \n"," plt.imshow(edges_grayscale)\n"," plt.axis('off');\n","\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," fig.tight_layout()\n","\n"," plt.show()\n","\n","\n","if Registration_mode == \"Perspective\":\n","\n"," print(\"Perspective registration selected\")\n","\n","# create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," \n"," if nChannel>1:\n"," I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 10),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_) \n"," \n"," axes.append( fig.add_subplot(rows, cols, s+1) )\n"," subplot_title=(str(s))\n"," axes[-1].set_title(subplot_title) \n"," plt.imshow(edges_grayscale)\n"," plt.axis('off');\n","\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," fig.tight_layout()\n","\n"," plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"keIQhCmOMv5S"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Ovu0ESxivcxx"},"source":["## **4.1. Prepare for training**\n","---\n","Here, we use the information from 3. 
to load the correct dependencies."]},{"cell_type":"code","metadata":{"id":"t4QTv4vQvbnS","cellView":"form"},"source":["#@markdown ##Load the dependencies required for training\n","\n","print(\"--------------------------------------------------\")\n","\n","# Remove the model name folder if exists\n","\n","if os.path.exists(Result_folder+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\"+W)\n"," shutil.rmtree(Result_folder+'/'+model_name)\n","os.makedirs(Result_folder+'/'+model_name)\n","\n","\n","\n","if Registration_mode == \"Affine\":\n","\n"," class HomographyNet(nn.Module):\n"," def __init__(self):\n"," super(HomographyNet, self).__init__()\n"," # affine transform basis matrices\n","\n"," self.B = torch.zeros(6,3,3).to(device)\n"," self.B[0,0,2] = 1.0\n"," self.B[1,1,2] = 1.0\n"," self.B[2,0,1] = 1.0\n"," self.B[3,1,0] = 1.0\n"," self.B[4,0,0], self.B[4,1,1] = 1.0, -1.0\n"," self.B[5,1,1], self.B[5,2,2] = -1.0, 1.0\n","\n"," self.v1 = torch.nn.Parameter(torch.zeros(6,1,1).to(device), requires_grad=True)\n"," self.vL = torch.nn.Parameter(torch.zeros(6,1,1).to(device), requires_grad=True)\n","\n"," def forward(self, s):\n"," C = torch.sum(self.B*self.vL,0)\n"," if s==0:\n"," C += torch.sum(self.B*self.v1,0)\n"," A = torch.eye(3).to(device)\n"," H = A\n"," for i in torch.arange(1,10):\n"," A = torch.mm(A/i,C)\n"," H = H + A\n"," return H\n","\n"," class MINE(nn.Module): #https://arxiv.org/abs/1801.04062\n"," def __init__(self):\n"," super(MINE, self).__init__()\n"," self.fc1 = nn.Linear(2*nChannel, n_neurons)\n"," self.fc2 = nn.Linear(n_neurons, n_neurons)\n"," self.fc3 = nn.Linear(n_neurons, 1)\n"," self.bsize = 1 # 1 may be sufficient\n","\n"," def forward(self, x, ind):\n"," x = x.view(x.size()[0]*x.size()[1],x.size()[2])\n"," MI_lb=0.0\n"," for i in range(self.bsize):\n"," ind_perm = ind[torch.randperm(len(ind))]\n"," z1 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(x[ind,:])))))\n"," z2 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(torch.cat((x[ind,0:nChannel],x[ind_perm,nChannel:2*nChannel]),1))))))\n"," MI_lb += torch.mean(z1) - torch.log(torch.mean(torch.exp(z2)))\n","\n"," return MI_lb/self.bsize\n","\n"," def AffineTransform(I, H, xv, yv):\n"," # apply affine transform\n"," xvt = (xv*H[0,0]+yv*H[0,1]+H[0,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," yvt = (xv*H[1,0]+yv*H[1,1]+H[1,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," J = F.grid_sample(I,torch.stack([xvt,yvt],2).unsqueeze(0)).squeeze()\n"," return J\n","\n","\n"," def multi_resolution_loss():\n"," loss=0.0\n"," for s in np.arange(L-1,-1,-1):\n"," if nChannel>1:\n"," Jw_ = AffineTransform(J_lst[s].unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.cat([I_lst[s],Jw_],0).permute(1,2,0),ind_lst[s])\n"," loss = loss - (1./L)*mi\n"," else:\n"," Jw_ = AffineTransform(J_lst[s].unsqueeze(0).unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.stack([I_lst[s],Jw_],2),ind_lst[s])\n"," loss = loss - (1./L)*mi\n","\n"," return loss\n","\n","\n","\n","if Registration_mode == \"Perspective\":\n","\n"," class HomographyNet(nn.Module):\n"," def __init__(self):\n"," super(HomographyNet, self).__init__()\n"," # affine transform basis matrices\n","\n"," self.B = torch.zeros(8,3,3).to(device)\n"," self.B[0,0,2] = 1.0\n"," self.B[1,1,2] = 1.0\n"," self.B[2,0,1] = 1.0\n"," self.B[3,1,0] = 1.0\n"," self.B[4,0,0], self.B[4,1,1] = 1.0, -1.0\n"," self.B[5,1,1], self.B[5,2,2] = -1.0, 1.0\n"," self.B[6,2,0] = 
1.0\n"," self.B[7,2,1] = 1.0\n","\n"," self.v1 = torch.nn.Parameter(torch.zeros(8,1,1).to(device), requires_grad=True)\n"," self.vL = torch.nn.Parameter(torch.zeros(8,1,1).to(device), requires_grad=True)\n","\n"," def forward(self, s):\n"," C = torch.sum(self.B*self.vL,0)\n"," if s==0:\n"," C += torch.sum(self.B*self.v1,0)\n"," A = torch.eye(3).to(device)\n"," H = A\n"," for i in torch.arange(1,10):\n"," A = torch.mm(A/i,C)\n"," H = H + A\n"," return H\n","\n","\n"," class MINE(nn.Module): #https://arxiv.org/abs/1801.04062\n"," def __init__(self):\n"," super(MINE, self).__init__()\n"," self.fc1 = nn.Linear(2*nChannel, n_neurons)\n"," self.fc2 = nn.Linear(n_neurons, n_neurons)\n"," self.fc3 = nn.Linear(n_neurons, 1)\n"," self.bsize = 1 # 1 may be sufficient\n","\n"," def forward(self, x, ind):\n"," x = x.view(x.size()[0]*x.size()[1],x.size()[2])\n"," MI_lb=0.0\n"," for i in range(self.bsize):\n"," ind_perm = ind[torch.randperm(len(ind))]\n"," z1 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(x[ind,:])))))\n"," z2 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(torch.cat((x[ind,0:nChannel],x[ind_perm,nChannel:2*nChannel]),1))))))\n"," MI_lb += torch.mean(z1) - torch.log(torch.mean(torch.exp(z2)))\n","\n"," return MI_lb/self.bsize\n","\n","\n"," def PerspectiveTransform(I, H, xv, yv):\n"," # apply homography\n"," xvt = (xv*H[0,0]+yv*H[0,1]+H[0,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," yvt = (xv*H[1,0]+yv*H[1,1]+H[1,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," J = F.grid_sample(I,torch.stack([xvt,yvt],2).unsqueeze(0)).squeeze()\n"," return J\n","\n","\n"," def multi_resolution_loss():\n"," loss=0.0\n"," for s in np.arange(L-1,-1,-1):\n"," if nChannel>1:\n"," Jw_ = PerspectiveTransform(J_lst[s].unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.cat([I_lst[s],Jw_],0).permute(1,2,0),ind_lst[s])\n"," loss = loss - (1./L)*mi\n"," else:\n"," Jw_ = PerspectiveTransform(J_lst[s].unsqueeze(0).unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.stack([I_lst[s],Jw_],2),ind_lst[s])\n"," loss = loss - (1./L)*mi\n","\n"," return loss\n","\n"," def histogram_mutual_information(image1, image2):\n"," hgram, x_edges, y_edges = np.histogram2d(image1.ravel(), image2.ravel(), bins=100)\n"," pxy = hgram / float(np.sum(hgram))\n"," px = np.sum(pxy, axis=1)\n"," py = np.sum(pxy, axis=0)\n"," px_py = px[:, None] * py[None, :]\n"," nzs = pxy > 0\n"," return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))\n","\n","\n","print(\"Done\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each iterations (round). A new network will be trained for each image that need to be registered.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n"]},{"cell_type":"code","metadata":{"id":"fisJmA13Mv5e","scrolled":true,"cellView":"form"},"source":["#@markdown ##Start training and the registration process\n","\n","start = time.time()\n","\n","loop_number = 1\n","\n","\n","\n","if Registration_mode == \"Affine\":\n","\n"," print(\"Affine registration.....\")\n","\n"," for image in os.listdir(Moving_image_folder):\n","\n"," if path.isfile(Fixed_image_folder):\n"," I = imread(Fixed_image_folder).astype(np.float32) # fixed image\n","\n"," if path.isdir(Fixed_image_folder):\n"," Fixed_image = os.listdir(Fixed_image_folder)\n"," I = imread(Fixed_image_folder+\"/\"+Fixed_image[0]).astype(np.float32) # fixed image\n","\n"," J = imread(Moving_image_folder+\"/\"+image).astype(np.float32)\n","\n"," # Here we generate the pyramidal images\n"," ifplot=True\n"," if np.ndim(I) == 3:\n"," nChannel=I.shape[2]\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," elif np.ndim(I) == 2:\n"," nChannel=1\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," else:\n"," print(\"Unknown rank for an image\")\n","\n","\n"," # create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n","\n","\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n","\n"," if nChannel>1:\n"," I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_)\n","\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," homography_net = HomographyNet().to(device)\n"," mine_net = MINE().to(device)\n","\n"," optimizer = optim.Adam([{'params': mine_net.parameters(), 'lr': 1e-3},\n"," {'params': homography_net.vL, 'lr': 5e-3},\n"," {'params': homography_net.v1, 'lr': 1e-4}], amsgrad=True)\n"," mi_list = []\n"," for itr in range(number_of_iteration):\n"," 
optimizer.zero_grad()\n"," loss = multi_resolution_loss()\n"," mi_list.append(-loss.item())\n"," loss.backward()\n"," optimizer.step()\n"," clear_output(wait=True)\n"," plt.plot(mi_list)\n"," plt.xlabel('Iteration number')\n"," plt.ylabel('MI')\n"," plt.title(image+\". Image registration \"+str(loop_number)+\" out of \"+str(len(os.listdir(Moving_image_folder)))+\".\")\n"," plt.show()\n","\n"," I_t = torch.tensor(I).to(device) # without Gaussian\n"," J_t = torch.tensor(J).to(device) # without Gaussian\n"," H = homography_net(0)\n"," if nChannel>1:\n"," J_w = AffineTransform(J_t.permute(2,0,1).unsqueeze(0), H, xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze().permute(1,2,0)\n"," else:\n"," J_w = AffineTransform(J_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n","\n"," #Apply registration to other channels\n","\n"," if Apply_registration_to_other_channels:\n","\n"," for n_channel in range(1, int(Number_of_other_channels)+1):\n","\n"," channel = imread(Additional_channels_folder+\"/C\"+str(n_channel)+\"_\"+image).astype(np.float32)\n"," channel_t = torch.tensor(channel).to(device)\n"," channel_w = AffineTransform(channel_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n"," channel_registered = channel_w.cpu().data.numpy()\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+\"C\"+str(n_channel)+\"_\"+image+\"_\"+Registration_mode+\"_registered.tif\", channel_registered)\n"," \n","# Export results to numpy array\n"," registered = J_w.cpu().data.numpy()\n","# Save results\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+image+\"_\"+Registration_mode+\"_registered.tif\", registered)\n","\n"," loop_number = loop_number + 1\n","\n"," print(\"Your images have been registered and saved in your result_folder\")\n","\n","\n","#Perspective registration\n","\n","if Registration_mode == \"Perspective\":\n","\n"," print(\"Perspective registration.....\")\n","\n"," for image in os.listdir(Moving_image_folder):\n","\n"," if path.isfile(Fixed_image_folder):\n"," I = imread(Fixed_image_folder).astype(np.float32) # fixed image\n","\n"," if path.isdir(Fixed_image_folder):\n"," Fixed_image = os.listdir(Fixed_image_folder)\n"," I = imread(Fixed_image).astype(np.float32) # fixed image\n","\n"," J = imread(Moving_image_folder+\"/\"+image).astype(np.float32)\n","\n"," # Here we generate the pyramidal images\n"," ifplot=True\n"," if np.ndim(I) == 3:\n"," nChannel=I.shape[2]\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," elif np.ndim(I) == 2:\n"," nChannel=1\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," else:\n"," print(\"Unknown rank for an image\")\n","\n","\n"," # create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n","\n"," if nChannel>1:\n"," 
I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 10),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_)\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," homography_net = HomographyNet().to(device)\n"," mine_net = MINE().to(device)\n","\n"," optimizer = optim.Adam([{'params': mine_net.parameters(), 'lr': 1e-3},\n"," {'params': homography_net.vL, 'lr': 1e-3},\n"," {'params': homography_net.v1, 'lr': 1e-4}], amsgrad=True)\n"," mi_list = []\n"," for itr in range(number_of_iteration):\n"," optimizer.zero_grad()\n"," loss = multi_resolution_loss()\n"," mi_list.append(-loss.item())\n"," loss.backward()\n"," optimizer.step()\n"," clear_output(wait=True)\n"," plt.plot(mi_list)\n"," plt.xlabel('Iteration number')\n"," plt.ylabel('MI')\n"," plt.title(image+\". Image registration \"+str(loop_number)+\" out of \"+str(len(os.listdir(Moving_image_folder)))+\".\")\n"," plt.show()\n","\n"," I_t = torch.tensor(I).to(device) # without Gaussian\n"," J_t = torch.tensor(J).to(device) # without Gaussian\n"," H = homography_net(0)\n"," if nChannel>1:\n"," J_w = PerspectiveTransform(J_t.permute(2,0,1).unsqueeze(0), H, xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze().permute(1,2,0)\n"," else:\n"," J_w = PerspectiveTransform(J_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n","\n"," #Apply registration to other channels\n","\n"," if Apply_registration_to_other_channels:\n","\n"," for n_channel in range(1, int(Number_of_other_channels)+1):\n","\n"," channel = imread(Additional_channels_folder+\"/C\"+str(n_channel)+\"_\"+image).astype(np.float32)\n"," channel_t = torch.tensor(channel).to(device)\n"," channel_w = PerspectiveTransform(channel_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n"," channel_registered = channel_w.cpu().data.numpy()\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+\"C\"+str(n_channel)+\"_\"+image+\"_Perspective_registered.tif\", channel_registered) \n","\n","\n","# Export results to numpy array\n"," registered = J_w.cpu().data.numpy()\n","# Save results\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+image+\"_Perspective_registered.tif\", registered)\n","\n"," loop_number = loop_number + 1\n","\n"," print(\"Your images have been registered and saved in your result_folder\")\n","\n","\n","# PDF export missing \n","\n","#pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PfTw_pQUUAqB"},"source":["## **4.3. 
Assess the registration**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"SrArBvqwYvc9","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Moving_image_folder)):\n","\n"," moving_image = imread(Moving_image_folder+\"/\"+file).astype(np.float32)\n"," \n"," registered_image = imread(Result_folder+\"/\"+model_name+\"/\"+file+\"_\"+Registration_mode+\"_registered.tif\").astype(np.float32)\n","\n","#Here we display one image\n","\n"," f=plt.figure(figsize=(20,20))\n"," plt.subplot(1,5,1)\n"," plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest')\n"," plt.title('Fixed image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,2)\n"," plt.imshow(moving_image, norm=simple_norm(moving_image, percent = 99), interpolation='nearest')\n"," plt.title('Moving image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,3)\n"," plt.imshow(registered_image, norm=simple_norm(registered_image, percent = 99), interpolation='nearest')\n"," plt.title(\"Registered image\")\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,4)\n"," plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest', cmap=\"Greens\")\n"," plt.imshow(moving_image, norm=simple_norm(moving_image, percent = 99), interpolation='nearest', cmap=\"Oranges\", alpha=0.5)\n"," plt.title(\"Fixed and moving images\")\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,5)\n"," plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest', cmap=\"Greens\")\n"," plt.imshow(registered_image, norm=simple_norm(registered_image, percent = 99), interpolation='nearest', cmap=\"Oranges\", alpha=0.5)\n"," plt.title(\"Fixed and Registered images\")\n"," plt.axis('off');\n","\n"," plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wgO7Ok1PBFQj"},"source":["## **4.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"XXsUh88HqYay"},"source":["# **5. 
Version log**\n","---\n","**v1.13**: \n","\n","* This version now includes built-in version check and the version log that you're reading now."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS"},"source":["#**Thank you for using DRMIME 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/DecoNoising_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/DecoNoising_2D_ZeroCostDL4Mic.ipynb index 1b996ddc..47759cef 100644 --- a/Colab_notebooks/Beta notebooks/DecoNoising_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Beta notebooks/DecoNoising_2D_ZeroCostDL4Mic.ipynb @@ -1,2061 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "DecoNoising_2D_ZeroCostDL4Mic.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "5oqGsnIIhcWe" - }, - "source": [ - "#**Currently missing features**\n", - "\n", - "- Learning rate is not currently saved \n", - "- Transfer learning is included but probably does not work well (last learning rate is not loaded)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "V9zNGvape2-I" - }, - "source": [ - "# **DecoNoising (2D)**\n", - "\n", - "---\n", - "\n", - " DecoNoising is an unsupervised denoising method that takes in consideration the expected point spread function of the images to denoise to improve the predictions. DecoNoising was originally published by [Goncharova *et al.* on arXiv](https://arxiv.org/abs/2008.08414). \n", - "\n", - " **This particular notebook enables self-supervised denoising of 2D dataset. If you are interested in 3D dataset, you should use the Noise2Void 3D notebook instead.**\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is largely based on the following paper:\n", - "\n", - "**Improving Blind Spot Denoising for Microscopy**\n", - "from Goncharova *et al.* published on arXiv in 2020 (https://arxiv.org/abs/2008.08414)\n", - "\n", - "And source code found in: /~https://github.com/juglab/DecoNoising\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use our notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vNMDQHm0Ah-Z" - }, - "source": [ - "# **0. 
Before getting started**\n", - "---\n", - "\n", - "Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n", - "\n", - "For Noise2Void to train, it only requires a single noisy image but multiple images can be used. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n", - "\n", - "Please note that you currently can **only use .tif files!**\n", - "\n", - "**We strongly recommend that you generate high signal to noise ration version of your noisy images (Quality control dataset). These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n", - "\n", - " You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n", - "\n", - "Here is a common data structure that can work:\n", - "\n", - "* Data\n", - " - **Training dataset**\n", - " - **Quality control dataset** (Optional but recomended)\n", - " - Low SNR images\n", - " - img_1.tif, img_2.tif\n", - " - High SNR images\n", - " - img_1.tif, img_2.tif \n", - " - **Data to be predicted** \n", - " - Results\n", - "\n", - "\n", - "The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4-r1gE7Iamv" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. 
Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "BDhmUgqCStlm", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-oqBTeLaImnU" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "01Djr8v-5pPk" - }, - "source": [ - "#@markdown ##Play the cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **2. 
Install DecoNoising and dependencies**\n", - "---" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "GT3rF4uA8gL6", - "cellView": "form" - }, - "source": [ - "Notebook_version = ['1.12']\n", - "\n", - "#@markdown ##Install DecoNoising and dependencies\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory\n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "# ------- Variable specific to DecoNoising -------\n", - "\n", - "!git clone /~https://github.com/juglab/DecoNoising\n", - "\n", - "!git clone /~https://github.com/juglab/PN2V\n", - "\n", - "sys.path.append('/content/PN2V')\n", - "\n", - "from pn2v import training\n", - "\n", - "sys.path.append('/content/DecoNoising')\n", - "\n", - "import matplotlib.pyplot as plt\n", - "from unet.model import UNet\n", - "from deconoising import utils\n", - "from deconoising.utils import PSNR\n", - "from deconoising import training\n", - "from deconoising import prediction\n", - "\n", - "import torch\n", - "\n", - "from scipy.ndimage import gaussian_filter\n", - "\n", - "device=utils.getDevice()\n", - "\n", - "!pip install wget\n", - "!pip install fpdf\n", - "\n", - "\n", - "# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - "import urllib\n", - "import os, random\n", - "import shutil \n", - "import zipfile\n", - "from tifffile import imread, imsave\n", - "import time\n", - "import wget\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import csv\n", - "from glob import glob\n", - "from scipy import signal\n", - "from scipy import ndimage\n", - "from skimage import io\n", - "from sklearn.linear_model import LinearRegression\n", - "from skimage.util import img_as_uint\n", - "import matplotlib as mpl\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from astropy.visualization import simple_norm\n", - "from skimage import img_as_float32\n", - "from fpdf import FPDF, HTMLMixin\n", - "from 
datetime import datetime\n", - "from pip._internal.operations.freeze import freeze\n", - "import subprocess\n", - "from datetime import datetime\n", - "from PIL import Image \n", - "from PIL.TiffTags import TAGS\n", - "import tensorflow as tf\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "\n", - "#Disable some of the tensorflow warnings\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "print(\"Libraries installed\")\n", - "\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "print('Notebook version: '+Notebook_version[0])\n", - "strlist = Notebook_version[0].split('.')\n", - "Notebook_version_main = strlist[0]+'.'+strlist[1]\n", - "if Notebook_version_main == Latest_notebook_version.columns:\n", - " print(\"This notebook is up-to-date.\")\n", - "else:\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'DecoNoising 2D'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and method:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['torch','numpy']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[0]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from 
scratch for '+str(number_of_epochs)+' epochs on '+str(len(Training_Filelist))+' images (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a virtual batch size of '+str(virtual_batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include torch (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(len(Training_Filelist))+' images (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a virtual batch size of '+str(virtual_batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include torch (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(26, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if Use_Data_augmentation:\n", - " aug_text = 'The dataset was augmented by default.'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>virtual_batch_size</td><td>{3}</td></tr>
<tr><td>number_of_steps</td><td>{4}</td></tr>
<tr><td>percentage_validation</td><td>{5}</td></tr>
<tr><td>initial_learning_rate</td><td>{6}</td></tr>
<tr><td>std_gaussian_for_PSF</td><td>{7}</td></tr>
<tr><td>pixel_size</td><td>{8}</td></tr>
\n", - " \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,virtual_batch_size,number_of_steps,percentage_validation,initial_learning_rate,std_gaussian_for_PSF,pixel_size)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " # pdf.set_font('')\n", - " # pdf.set_font('Arial', size = 10, style = 'B')\n", - " # pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " # pdf.set_font('')\n", - " # pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training Image', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample.png').shape\n", - " pdf.image('/content/TrainingDataExample.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- DecoNoising: Goncharova Anna et al. \"Improving Blind Spot Denoising for Microscopy\" https://arxiv.org/abs/2008.08414. 
2020.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", - "\n", - "\n", - "\n", - " #Make a pdf summary of the QC results\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'DecoNoising 2D'\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n", - " if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n", - " pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n", - " pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " mSSIM_PvsGT = header[1]\n", - " mSSIM_SvsGT = header[2]\n", - " NRMSE_PvsGT = header[3]\n", - " NRMSE_SvsGT = header[4]\n", - " PSNR_PvsGT = header[5]\n", - " PSNR_SvsGT = header[6]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n", - " html = html+header\n", - " for row in metrics:\n", - " image = row[0]\n", - " mSSIM_PvsGT = row[1]\n", - " mSSIM_SvsGT = row[2]\n", - " NRMSE_PvsGT = row[3]\n", - " NRMSE_SvsGT = row[4]\n", - " PSNR_PvsGT = row[5]\n", - " 
PSNR_SvsGT = row[6]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
\"\"\"\n", - " \n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- DecoNoising: Goncharova Anna et al. \"Improving Blind Spot Denoising for Microscopy\" https://arxiv.org/abs/2008.08414. 2020.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "\n", - "def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n", - " with Image.open(TIFFpath) as img:\n", - " meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n", - "\n", - "\n", - " # TIFF tags\n", - " # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n", - " # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n", - " ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n", - " width = meta_dict['ImageWidth'][0]\n", - " height = meta_dict['ImageLength'][0]\n", - "\n", - " xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n", - "\n", - " if len(xResolution) == 1:\n", - " xResolution = xResolution[0]\n", - " elif len(xResolution) == 2:\n", - " xResolution = xResolution[0]/xResolution[1]\n", - " else:\n", - " print('Image resolution not defined.')\n", - " xResolution = 1\n", - "\n", - " if ResolutionUnit == 2:\n", - " # Units given are in inches\n", - " pixel_size = 0.025*1e9/xResolution\n", - " elif ResolutionUnit == 3:\n", - " # Units given are in cm\n", - " pixel_size = 0.01*1e9/xResolution\n", - " else: \n", - " # ResolutionUnit is therefore 1\n", - " print('Resolution unit not defined. Assuming: um')\n", - " pixel_size = 1e3/xResolution\n", - "\n", - " if display:\n", - " print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n", - " print('Image size: '+str(width)+'x'+str(height))\n", - " \n", - " return (pixel_size, width, height)\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Fw0kkTU6CsU4" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WzYAA-MuaYrT" - }, - "source": [ - "## **3.1. Setting main training parameters**\n", - "---\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CB6acvUFtWqd" - }, - "source": [ - " **Paths for training, predictions and results**\n", - "\n", - "**`Training_source:`:** These is the path to your folders containing the Training_source (noisy images). 
To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "\n", - "**Training Parameters**\n", - "\n", - "\n", - "**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 100**\n", - "\n", - "**`patch_size`:** Deconoising divides the image into patches for training. Input the size of the patches (length of a side). The value should be between 64 and the dimensions of the image and divisible by 8. **Default value: 80**\n", - " \n", - "**Advanced Parameters - experienced users only**\n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. **Default value: 4**\n", - "\n", - "**`virtual_batch_size:`** The number of batches that are processed before a gradient step is performed. **Default value: 20**\n", - "\n", - "**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of images / batch_size**\n", - "\n", - "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 20**\n", - "\n", - "**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.001**\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ewpNJ_I0Mv47", - "cellView": "form" - }, - "source": [ - "\n", - "#@markdown ###Path to training image(s): \n", - "Training_source = \"\" #@param {type:\"string\"}\n", - "\n", - "#compatibility to easily change the name of the parameters\n", - "training_images = Training_source \n", - "\n", - "#@markdown ### Model name and path:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "full_model_path = model_path+'/'+model_name+'/'\n", - "\n", - "\n", - "#@markdown Number of epochs:\n", - "number_of_epochs = 200#@param {type:\"number\"}\n", - "\n", - "#@markdown Patch size (pixels)\n", - "patch_size = 80#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "\n", - "Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please input:\n", - "batch_size = 4#@param {type:\"number\"}\n", - "virtual_batch_size = 20 #@param {type:\"number\"}\n", - "number_of_steps = 12#@param {type:\"number\"}\n", - "percentage_validation = 20#@param {type:\"number\"}\n", - "initial_learning_rate = 0.001 #@param {type:\"number\"}\n", - "\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " # number_of_steps is defined in the following cell in this case\n", - " batch_size = 4\n", - " percentage_validation = 20\n", - " initial_learning_rate = 0.001\n", - " virtual_batch_size = 20\n", - " \n", - "#here we check that no model with the same name already exist, if so print a warning\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n", - " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n", - "\n", - "\n", - "# This will open a randomly chosen dataset input image\n", - "random_choice = random.choice(os.listdir(Training_source))\n", - "x = imread(Training_source+\"/\"+random_choice)\n", - "\n", - "# Here we check the image dimensions\n", - "\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "\n", - "#Hyperparameters failsafes\n", - "\n", - "# Here we check that patch_size is smaller than the smallest xy dimension of the image \n", - "if patch_size > min(Image_Y, Image_X):\n", - " patch_size = min(Image_Y, Image_X)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "# Here we check that patch_size is divisible by 8\n", - "if not patch_size % 8 == 0:\n", - " patch_size = ((int(patch_size / 8)-1) * 8)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "\n", - "# Here we check that the input images contains the expected dimensions\n", - "if len(x.shape) == 2:\n", - " print('Loaded images (width, length) =', x.shape)\n", - "\n", - "\n", - "if not len(x.shape) == 2:\n", - " print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n", - "\n", - "#Here we split the images between training and validation\n", - "\n", - "# Here we count the number of files in the training target folder\n", - "Noisy_Filelist = os.listdir(Training_source)\n", - "Noisy_number_files = len(Noisy_Filelist)\n", - "\n", - "# Here we count the number of file to use for validation\n", - "Noisy_for_validation = int((Noisy_number_files)/percentage_validation)\n", - "\n", - "if Noisy_for_validation == 0:\n", - " Noisy_for_validation = 1\n", - "\n", - "#Here we split the training dataset between training and validation\n", - "# Everything is copied in the /Content Folder\n", - "Training_source_temp = \"/content/training_source\"\n", - "\n", - "if os.path.exists(Training_source_temp):\n", - " shutil.rmtree(Training_source_temp)\n", - "os.makedirs(Training_source_temp)\n", - "\n", - "\n", - "Validation_source_temp = \"/content/validation_source\"\n", - "\n", - "if os.path.exists(Validation_source_temp):\n", - " shutil.rmtree(Validation_source_temp)\n", - "os.makedirs(Validation_source_temp)\n", - "\n", - "list_source = os.listdir(os.path.join(Training_source))\n", - "\n", - "#Move files into the temporary source and target directories:\n", - "\n", - "for f in os.listdir(os.path.join(Training_source)):\n", - " shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n", - "\n", - "#Here we move images to be used for validation\n", - "for i in range(Noisy_for_validation): \n", - " shutil.move(Training_source_temp+\"/\"+list_source[i], Validation_source_temp+\"/\"+list_source[i])\n", - "\n", - "# Here we disable pre-trained model by default (in case the next cell is not run)\n", - "Use_pretrained_model = False\n", - "\n", - "# Here we enable data augmentation by default (in case the cell is not ran)\n", - "Use_Data_augmentation = True\n", - "\n", - "print(\"Parameters initiated.\")\n", - "\n", - "#Here we display one image\n", - "norm = simple_norm(x, percent = 99)\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n", - "plt.title('Training source')\n", - "plt.axis('off');\n", - "plt.savefig('/content/TrainingDataExample.png',bbox_inches='tight',pad_inches=0)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "j5aThu1DaB04" - }, - "source": [ - "## **3.2. PSF simulator**\n", - "---\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ibpaGXPBQaNP" - }, - "source": [ - " DecoNoising uses the expected point spread function of your images to denoise and improve the predictions. Here we simulate a point spread function using a Gaussian function. \n", - "\n", - "**`Lambda`:** This corresponds to the wavelength **in nanoometers** of the fluorescence emission in the experiment.\n", - "\n", - "**`NA`:** This corresponds to the numerical aperture of the microscope objective used in the experiment.\n", - "\n", - "**`pixel_size:`:** Indicate the pixel size of your images (**in nanometers**). This information can be found, for instance, by opening your images in Fiji. It can also sometimes be directly extracted from the metadata.\n", - "\n", - "The simulator uses the Gaussian approximation of the point spread function (PSF). The nominal value of the standard deviation of the corresponding 2D Gaussian function here is evaluated using `0.21 x Lambda / NA`, described by [Zhang *et al.*, Applied Optics 2007](https://doi.org/10.1364/AO.46.001819)). 
This equation however likely represents an optimistic estimation of the size of the PSF and therefore should be taken as an under-estimation of the real experimental PSF. Users may consider increasing Lambda or decreasing NA slightly to take into account potential deviation from this ideal case.\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "u2ag5fe_Ine2", - "cellView": "form" - }, - "source": [ - "\n", - "#@markdown ###Run this cell to simulate your PSF: \n", - "\n", - "Lambda = 500 #@param {type:\"number\"}\n", - "NA = 1.2 #@param {type:\"number\"}\n", - "\n", - "Load_pixel_size_from_metadata = False#@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "\n", - "pixel_size = 250#@param {type:\"number\"}\n", - "\n", - "sigma_nm = 0.21*Lambda/NA\n", - "print(\"Gaussian PSF sigma: \"+str(sigma_nm)+\" um\")\n", - "std_gaussian_for_PSF = sigma_nm/pixel_size\n", - "print(\"Gaussian PSF sigma: \"+str(std_gaussian_for_PSF)+\" pixels\")\n", - "\n", - "\n", - "if Load_pixel_size_from_metadata:\n", - " pixel_size,_,_ = getPixelSizeTIFFmetadata(Training_source+\"/\"+random_choice, True)\n", - "\n", - "#size_of_psf = 24*int(std_gaussian_for_PSF)+1 # 12 STD wide is good enough\n", - "\n", - "size_of_psf = patch_size+1\n", - "\n", - "if size_of_psf > patch_size:\n", - " size_of_psf = (patch_size+1)\n", - "\n", - "if size_of_psf > patch_size:\n", - " size_of_psf = (patch_size+1)\n", - "\n", - "\n", - "def artificial_psf(size_of_psf = size_of_psf, std_gauss = std_gaussian_for_PSF): \n", - " filt = np.zeros((size_of_psf, size_of_psf))\n", - " p = (size_of_psf - 1)//2 # integer division\n", - " filt[p,p] = 1\n", - " filt = torch.tensor(gaussian_filter(filt,std_gaussian_for_PSF).reshape(1,1,size_of_psf,size_of_psf).astype(np.float32))\n", - " filt = filt/torch.sum(filt)\n", - " return filt\n", - "\n", - "psf_tensor = artificial_psf(std_gauss = std_gaussian_for_PSF)\n", - "# plt.imshow(psf_tensor[0, 0, 20:60, 20:60], cmap = 'gray', interpolation= 'none')\n", - "plt.imshow(psf_tensor[0, 0, :, :], cmap = 'gray', interpolation= 'none')\n", - "# plt.axis('off')\n", - "\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xGcl7WGP4WHt" - }, - "source": [ - "## **3.3. Data augmentation**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5Lio8hpZ4PJ1" - }, - "source": [ - "Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n", - "\n", - "Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n", - "\n", - " **By default data augmentation is enabled. 
Disable this option is you run out of RAM during the training**.\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "htqjkJWt5J_8" - }, - "source": [ - "#Data augmentation\n", - "\n", - "#@markdown ##Play this cell to enable or disable data augmentation: \n", - "\n", - "Use_Data_augmentation = True #@param {type:\"boolean\"}\n", - "\n", - "if Use_Data_augmentation:\n", - " print(\"Data augmentation enabled\")\n", - "\n", - "if not Use_Data_augmentation:\n", - " print(\"Data augmentation disabled\")" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bQDuybvyadKU" - }, - "source": [ - "\n", - "## **3.4. Using weights from a pre-trained model as initial weights** NOT implemented \n", - "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deconoising 2D model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8vPkzEBNamE4", - "cellView": "form" - }, - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "\n", - "pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n", - "\n", - "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n", - "\n", - "\n", - "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "pretrained_model_name = os.path.basename(pretrained_model_path)\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - "# --------------------- Load the model from the choosen path ------------------------\n", - " if pretrained_model_choice == \"Model_from_file\":\n", - " h5_file_path = os.path.join(pretrained_model_path, Weights_choice+\"_\"+pretrained_model_name+\".net\")\n", - "\n", - "\n", - "# --------------------- Download the a model provided in the XXX ------------------------\n", - "\n", - " if pretrained_model_choice == \"Model_name\":\n", - " pretrained_model_name = \"Model_name\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path) \n", - " wget.download(\"\", pretrained_model_path)\n", - " h5_file_path = os.path.join(pretrained_model_path, 
\"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "# --------------------- Add additional pre-trained models here ------------------------\n", - "\n", - "\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n", - " if not os.path.exists(h5_file_path):\n", - " print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n", - " Use_pretrained_model = False\n", - "\n", - " \n", - "# If the model path contains a pretrain model, we load the training rate, \n", - " if os.path.exists(h5_file_path):\n", - "#Here we check if the learning rate can be loaded from the quality control folder\n", - " if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - "\n", - " with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " \n", - " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"pretrained network learning rate found\")\n", - " #find the last learning rate\n", - " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n", - " #Find the learning rate corresponding to the lowest validation loss\n", - " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n", - " #print(min_val_loss)\n", - " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n", - "\n", - " if Weights_choice == \"last\":\n", - " print('Last learning rate: '+str(lastLearningRate))\n", - "\n", - " if Weights_choice == \"best\":\n", - " print('Learning rate of best validation loss: '+str(bestLearningRate))\n", - "\n", - " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n", - "\n", - "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n", - " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - "\n", - "\n", - "# Display info about the pretrained model to be loaded (or not)\n", - "if Use_pretrained_model:\n", - " print('Weights found in:')\n", - " print(h5_file_path)\n", - " print('will be loaded prior to training.')\n", - "\n", - "else:\n", - " print(bcolors.WARNING+'No pretrained nerwork will be used.')\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rQndJj70FzfL" - }, - "source": [ - "# **4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tGW2iaU6X5zi" - }, - "source": [ - "## **4.1. Prepare the training data and model for training**\n", - "---\n", - "Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WMJnGJpCMa4y", - "cellView": "form" - }, - "source": [ - "#@markdown ##Create the model and dataset objects\n", - "\n", - "# --------------------- Here we delete the model folder if it already exists ------------------------\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "\n", - "os.makedirs(full_model_path)\n", - "\n", - "\n", - "#Data preparation: load all the images into a single numpy array\n", - "\n", - "Training_Filelist = os.listdir(Training_source_temp)\n", - "\n", - "Nb_of_training_images = len(Training_Filelist)\n", - "\n", - "Validation_Filelist = os.listdir(Validation_source_temp)\n", - "\n", - "Nb_of_validation_images = len(Validation_Filelist)\n", - "\n", - "print(str(Nb_of_training_images)+\" training images\")\n", - "print(str(Nb_of_validation_images)+\" validation images\")\n", - "\n", - "my_train_data=np.array([np.array(imread(Training_source_temp+\"/\"+fname)) for fname in Training_Filelist]).astype(np.float32)\n", - "my_val_data=np.array([np.array(imread(Validation_source_temp+\"/\"+fname)) for fname in Validation_Filelist]).astype(np.float32)\n", - "\n", - "#Subtract the mean value of the background. It can be measured in Fiji, for example. This is important for the positivity constraint, which requires the background to be at 0.\n", - "#my_train_data = my_train_data - Average_background\n", - "#my_val_data = my_val_data - Average_background\n", - "\n", - "\n", - "#Automated subtraction of the lowest value in my_train_data\n", - "n_image_my_train_data = my_train_data.shape[0]\n", - "for n in range(n_image_my_train_data):\n", - " my_train_data[n] = my_train_data[n] - np.amin(my_train_data[n])\n", - "\n", - "#Automated subtraction of the lowest value in my_val_data\n", - "\n", - "n_image_my_val_data = my_val_data.shape[0]\n", - "for n in range(n_image_my_val_data):\n", - " my_val_data[n] = my_val_data[n] - np.amin(my_val_data[n])\n", - "\n", - "#Here we automatically define number_of_steps as a function of the training data and batch size\n", - "\n", - "if (Use_Default_Advanced_Parameters) or (number_of_steps == 0): \n", - " #number_of_steps = int(Image_X*Image_Y/(patch_size*patch_size)*(int(len(Training_Filelist)/batch_size)+1))\n", - " number_of_steps = 20\n", - "\n", - " print(str(number_of_steps)+\" steps per EPOCH\")\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "#Here we ensure that the learning rate is set correctly when using pre-trained models\n", - "if Use_pretrained_model:\n", - " if Weights_choice == \"last\":\n", - " initial_learning_rate = lastLearningRate\n", - "\n", - " if Weights_choice == \"best\": \n", - " initial_learning_rate = bestLearningRate\n", - "# --------------------- ---------------------- ------------------------\n", - " \n", - "#Prepare the model\n", - "\n", - "device=utils.getDevice()\n", - "\n", - "# The network requires only a single output unit per pixel\n", - "\n", - "if Use_pretrained_model:\n", - " net = torch.load(h5_file_path)\n", - "\n", - "if not Use_pretrained_model:\n", - " net = UNet(1, depth=3)\n", - "\n", - "net.psf=psf_tensor.to(device)\n", - "\n", - "#Export summary of training parameters as pdf\n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = 
Use_pretrained_model)\n", - "\n", - "print(\"Setup done.\")\n", - "\n", - "\n", - "# creates a plot and shows one training patch and one validation patch.\n", - "\n", - "norm = simple_norm(my_train_data[0], percent = 99)\n", - "\n", - "plt.figure(figsize=(10,5))\n", - "plt.subplot(121)\n", - "plt.imshow(my_train_data[0], norm=norm, cmap='magma')\n", - "plt.axis('off')\n", - "plt.title('Training image');\n", - "plt.subplot(122)\n", - "plt.imshow(my_val_data[0], norm=norm, cmap='magma')\n", - "plt.axis('off')\n", - "plt.title('Validation image');\n", - "\n", - "plt.show()\n", - "\n", - "#pdf_export(pretrained_model = Use_pretrained_model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wQPz0F6JlvJR" - }, - "source": [ - "## **4.2. Start Training**\n", - "---\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this \n", - "point.\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "j_Qm5JBmlvJg", - "cellView": "form" - }, - "source": [ - "start = time.time()\n", - "\n", - "#@markdown ##Start training\n", - "\n", - "\n", - "\n", - "# Start training.\n", - "trainHist, valHist = training.trainNetwork(net = net, trainData = my_train_data, valData = my_val_data,\n", - " postfix = model_name, directory = full_model_path,\n", - " device = device, augment=Use_Data_augmentation, numOfEpochs = number_of_epochs, stepsPerEpoch = number_of_steps, \n", - " virtualBatchSize = virtual_batch_size, batchSize = batch_size, patchSize=patch_size, learningRate = initial_learning_rate,\n", - " psf = psf_tensor.to(device),positivity_constraint = 1)\n", - "\n", - "\n", - "print(\"Training done.\")\n", - "\n", - "\n", - "# convert the history.history dict to a pandas DataFrame: \n", - "#lossData = pd.DataFrame(history.valHist) \n", - "\n", - "if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n", - " shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "# The training evaluation.csv is saved (overwrites the Files if needed). 
\n", - "lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n", - "with open(lossDataCSVpath, 'w') as f:\n", - " writer = csv.writer(f)\n", - " writer.writerow(['loss','val_loss'])\n", - " for i in range(len(valHist)):\n", - " writer.writerow([trainHist[i], valHist[i]])\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "pdf_export(trained = True, pretrained_model = Use_pretrained_model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QYuIOWQ3imuU" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "zazOZ3wDx0zQ" - }, - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "QC_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "QC_model_name = os.path.basename(QC_model_folder)\n", - "QC_model_path = os.path.dirname(QC_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", - "\n", - "full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n", - "if os.path.exists(full_QC_model_path):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", - "else:\n", - " \n", - " print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yDY9dtzdUTLh" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. 
Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased." - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "vMzSP50kMv5p" - }, - "source": [ - "#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n", - "\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "\n", - "with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[0]))\n", - " vallossDataFromCSV.append(float(row[1]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (log scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n", - "plt.show()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "biT9FI9Ri77_" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "\n", - "This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n", - "\n", - "**1. The SSIM (structural similarity) map** \n", - "\n", - "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n", - "\n", - "**mSSIM** is the SSIM value calculated across the entire window of both images.\n", - "\n", - "**The output below shows the SSIM maps with the mSSIM**\n", - "\n", - "**2. The RSE (Root Squared Error) map** \n", - "\n", - "This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. 
A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n", - "\n", - "\n", - "**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n", - "\n", - "**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n", - "\n", - "**The output below shows the RSE maps with the NRMSE and PSNR values.**\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "nAs4Wni7VYbq" - }, - "source": [ - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_QC_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n", - "\n", - "# Create a quality control/Prediction Folder\n", - "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n", - " shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "\n", - "# List Tif images in Source_QC_folder\n", - "Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n", - "Z = sorted(glob(Source_QC_folder_tif))\n", - "Z = list(map(imread,Z))\n", - "\n", - "print('Number of test dataset found in the folder: '+str(len(Z)))\n", - "\n", - "\n", - "# Activate the pretrained models. \n", - "net = torch.load(QC_model_path+\"/\"+QC_model_name+\"/\" + Weights_choice+\"_\" + QC_model_name + \".net\")\n", - "\n", - "\n", - "# Perform prediction on all datasets in the Source_QC folder\n", - "for filename in os.listdir(Source_QC_folder):\n", - " img = imread(os.path.join(Source_QC_folder, filename))\n", - " deconvolvedResult, denoisedResult = prediction.tiledPredict(img, net ,ps=256, overlap=48, device=device)\n", - "\n", - " os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - " imsave(filename, denoisedResult)\n", - "\n", - "def ssim(img1, img2):\n", - " return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n", - "\n", - "\n", - "def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " \"\"\"Percentile-based image normalization.\"\"\"\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - "def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " if dtype is not None:\n", - " x = x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = 
np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - "def norm_minmse(gt, x, normalize_gt=True):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - "\n", - " \"\"\"\n", - " normalizes and affinely scales an image pair such that the MSE is minimized \n", - " \n", - " Parameters\n", - " ----------\n", - " gt: ndarray\n", - " the ground truth image \n", - " x: ndarray\n", - " the image that will be affinely scaled \n", - " normalize_gt: bool\n", - " set to True of gt image should be normalized (default)\n", - " Returns\n", - " -------\n", - " gt_scaled, x_scaled \n", - " \"\"\"\n", - " if normalize_gt:\n", - " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", - " x = x.astype(np.float32, copy=False) - np.mean(x)\n", - " #x = x - np.mean(x)\n", - " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", - " #gt = gt - np.mean(gt)\n", - " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", - " return gt, scale * x\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - "with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n", - "\n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - "\n", - " for i in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", - " print('Running QC on: '+i)\n", - " # -------------------------------- Target test data (Ground truth) --------------------------------\n", - " test_GT = io.imread(os.path.join(Target_QC_folder, i))\n", - "\n", - " # -------------------------------- Source test data --------------------------------\n", - " test_source = io.imread(os.path.join(Source_QC_folder,i))\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n", - " test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n", - "\n", - " # -------------------------------- Prediction --------------------------------\n", - " test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n", - " test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n", - "\n", - "\n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n", - "\n", - " #Save ssim_maps\n", - " img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n", - " img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n", - " \n", - " # Calculate the Root Squared Error (RSE) maps\n", - " img_RSE_GTvsPrediction 
= np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", - " img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n", - "\n", - " # Save SE maps\n", - " img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n", - " img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n", - "\n", - "\n", - " # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n", - "\n", - " # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n", - " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", - " NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n", - " \n", - " # We can also measure the peak signal to noise ratio between the images\n", - " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", - " PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n", - "\n", - " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n", - "\n", - "\n", - "# All data is now processed saved\n", - "Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n", - "\n", - "img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n", - "img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n", - "img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n", - "\n", - "norm_img_GT = simple_norm(img_GT, percent = 99)\n", - "norm_img_Source = simple_norm(img_Source, percent = 99)\n", - "norm_img_Prediction = simple_norm(img_Prediction, percent = 99)\n", - "\n", - "\n", - "plt.figure(figsize=(15,15))\n", - "# Currently only displays the last computed set, from memory\n", - "# Target (Ground-truth)\n", - "plt.subplot(3,3,1)\n", - "plt.axis('off')\n", - "plt.imshow(img_GT, norm=norm_img_GT, cmap='magma')\n", - "plt.title('Target',fontsize=15)\n", - "\n", - "# Source\n", - "plt.subplot(3,3,2)\n", - "plt.axis('off')\n", - "plt.imshow(img_Source, norm=norm_img_Source, cmap='magma')\n", - "plt.title('Source',fontsize=15)\n", - "\n", - "#Prediction\n", - "plt.subplot(3,3,3)\n", - "plt.axis('off')\n", - "plt.imshow(img_Prediction, norm=norm_img_Prediction, cmap='magma')\n", - "plt.title('Prediction',fontsize=15)\n", - "\n", - "#Setting up colours\n", - "cmap = plt.cm.CMRmap\n", - "\n", - "#SSIM between GT and Source\n", - "plt.subplot(3,3,5)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - "plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - "plt.title('Target vs. 
Source',fontsize=15)\n", - "plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", - "plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - "plt.subplot(3,3,6)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - "plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - "plt.title('Target vs. Prediction',fontsize=15)\n", - "plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#Root Squared Error between GT and Source\n", - "plt.subplot(3,3,8)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n", - "plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n", - "plt.title('Target vs. Source',fontsize=15)\n", - "plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n", - "#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n", - "plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#Root Squared Error between GT and Prediction\n", - "plt.subplot(3,3,9)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - "plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", - "plt.title('Target vs. Prediction',fontsize=15)\n", - "plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n", - "plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "qc_pdf_export()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "69aJVFfsqXbY" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tcPNRq1TrMPB" - }, - "source": [ - "## **6.1. Generate prediction(s) from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n", - "\n", - "**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output images.\n", - "\n", - "**`Data_type`:** Please indicate if the images you want to predict are single images or stacks" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Am2JSmpC0frj", - "cellView": "form" - }, - "source": [ - "Single_Images = 1\n", - "Stacks = 2\n", - "\n", - "from deconoising import prediction\n", - "\n", - "\n", - "#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n", - "\n", - "#@markdown ###Path to data to analyse and where predicted output should be saved:\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Result_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ###Are your data single images or stacks?\n", - "\n", - "Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n", - "\n", - "# model name and path\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n", - "\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ###Would you like to save the deconvolved images?\n", - "Save_deconvolved_images = True #@param {type:\"boolean\"}\n", - "\n", - "\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "\n", - "#Activate the pretrained model. 
\n", - "net = torch.load(full_Prediction_model_path+\"/\" + Weights_choice+\"_\" + Prediction_model_name + \".net\")\n", - "\n", - " # r=root, d=directories, f = files\n", - "for r, d, f in os.walk(Data_folder):\n", - " for file in f:\n", - " if \".tif\" in file:\n", - " print(os.path.join(r, file))\n", - "\n", - "if Data_type == 1 :\n", - " print(\"Single images are now beeing predicted\")\n", - "\n", - "# Loop through the files\n", - " for r, d, f in os.walk(Data_folder):\n", - " for file in f:\n", - " base_filename = os.path.basename(file)\n", - " input_train = imread(os.path.join(r, file))\n", - " \n", - " # We are using tiling to fit the image into memory\n", - " # If you get an error try a smaller patch size (ps)\n", - " # Here we are predicting the deconvolved and denoised image\n", - " deconvolvedResult, denoisedResult = prediction.tiledPredict(input_train, net ,ps=256, overlap=48, device=device)\n", - "\n", - " if Save_deconvolved_images:\n", - " io.imsave(Result_folder+'/deconvolved_'+base_filename,deconvolvedResult)\n", - " \n", - " io.imsave(Result_folder+'/denoised_'+base_filename,denoisedResult)\n", - "\n", - " print(\"Images saved into folder:\", Result_folder)\n", - "\n", - "if Data_type == 2 :\n", - " print(\"Stacks are now beeing predicted\")\n", - " for r, d, f in os.walk(Data_folder):\n", - " for file in f:\n", - " base_filename = os.path.basename(file)\n", - " print(\"Denoising: \"+str(base_filename))\n", - " timelapse = imread(os.path.join(r, file))\n", - " n_timepoint = timelapse.shape[0]\n", - " deconvolvedResult_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n", - " denoisedResult_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n", - "\n", - " for t in range(n_timepoint):\n", - " img_t = timelapse[t]\n", - " deconvolvedResult_stack[t], denoisedResult_stack[t] = prediction.tiledPredict(img_t, net ,ps=256, overlap=48, device=device)\n", - " \n", - " if Save_deconvolved_images:\n", - " deconvolvedResult_stack_32 = img_as_float32(deconvolvedResult_stack, force_copy=False)\n", - " imsave(Result_folder+'/deconvolved_'+base_filename, deconvolvedResult_stack_32) \n", - "\n", - " denoisedResult_stack_32 = img_as_float32(denoisedResult_stack, force_copy=False)\n", - " imsave(Result_folder+'/denoised_'+base_filename, denoisedResult_stack_32)\n", - " del denoisedResult_stack_32 \n", - " del denoisedResult_stack\n", - " del timelapse\n", - " del deconvolvedResult_stack\n", - " \n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "67_8rEKp8C-z" - }, - "source": [ - "## **6.2. 
Assess predicted output**\n", - "---\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "n-stU-f08Cae", - "cellView": "form" - }, - "source": [ - "# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n", - "\n", - "# This will display a randomly chosen dataset input and predicted output\n", - "\n", - "\n", - "random_choice = random.choice(os.listdir(Data_folder))\n", - "input_train = imread(Data_folder+\"/\"+random_choice)\n", - "\n", - "#os.chdir(Result_folder)\n", - "#deconvolvedResult = imread(Result_folder+'/denoised_'+random_choice)\n", - "denoisedResult = imread(Result_folder+'/denoised_'+random_choice)\n", - "\n", - "norm = simple_norm(input_train, percent = 99)\n", - "\n", - "\n", - "if Data_type == 1 :\n", - "\n", - " plt.figure(figsize=(15, 15))\n", - " plt.subplot(1, 2, 1)\n", - " plt.title('Input image')\n", - " plt.imshow(input_train, norm=norm, cmap='magma')\n", - " plt.axis('off');\n", - "\n", - "\n", - " plt.subplot(1, 2, 2)\n", - " plt.title('Denoised output')\n", - " plt.imshow(denoisedResult, norm=norm, cmap='magma')\n", - " plt.axis('off');\n", - "\n", - " plt.figure(figsize=(15, 15))\n", - " plt.subplot(1, 2, 1)\n", - " plt.title('Input image')\n", - " plt.imshow(input_train[100:200,150:250], norm=norm, cmap='magma')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1, 2, 2)\n", - " plt.title('Denoised output')\n", - " plt.imshow(denoisedResult[100:200,150:250], norm=norm, cmap='magma') \n", - " plt.axis('off');\n", - " \n", - "\n", - "if Data_type == 2 :\n", - " \n", - "\n", - " plt.figure(figsize=(15, 15))\n", - " plt.subplot(1, 2, 1)\n", - " plt.title('Input image')\n", - " plt.imshow(input_train[1], cmap='magma')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1, 2, 2)\n", - " plt.title('Denoised output')\n", - " plt.imshow(denoisedResult[1], cmap='magma')\n", - " plt.axis('off');\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.3. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u4pcBe8Z3T2J" - }, - "source": [ - "#**Thank you for using DecoNoising 2D!**" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"DecoNoising_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hMjEc-Ex7j-jeYGclaPw2x3OgbkeC6Bl","timestamp":1610626439596},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"cells":[{"cell_type":"markdown","metadata":{"id":"5oqGsnIIhcWe"},"source":["#**Currently missing features**\n","\n","- Learning rate is not currently saved \n","- Transfer learning is included but probably does not work well (last learning rate is not loaded)\n"]},{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **DecoNoising (2D)**\n","\n","---\n","\n"," DecoNoising is an unsupervised denoising method that takes into consideration the expected point spread function of the images to be denoised in order to improve the predictions. DecoNoising was originally published by [Goncharova *et al.* on arXiv](https://arxiv.org/abs/2008.08414). \n","\n"," **This particular notebook enables self-supervised denoising of 2D datasets. If you are interested in 3D datasets, you should use the Noise2Void 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Improving Blind Spot Denoising for Microscopy**\n","from Goncharova *et al.* published on arXiv in 2020 (https://arxiv.org/abs/2008.08414)\n","\n","And source code found in: /~https://github.com/juglab/DecoNoising\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading the text cell. You can create a new text cell by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modified by selecting the cell. 
To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples of how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","DecoNoising only requires a single noisy image to train, but multiple images can be used. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal-to-noise ratio versions of your noisy images (Quality control dataset). These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recommended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - Results\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as a csv file. 
Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Install DecoNoising and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"GT3rF4uA8gL6","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'DecoNoising'\n","\n","\n","#@markdown ##Install DecoNoising and dependencies\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory\n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","# ------- Variable specific to DecoNoising -------\n","\n","!git clone /~https://github.com/juglab/DecoNoising\n","\n","!git clone /~https://github.com/juglab/PN2V\n","\n","sys.path.append('/content/PN2V')\n","\n","from pn2v import training\n","\n","sys.path.append('/content/DecoNoising')\n","\n","import matplotlib.pyplot as plt\n","from unet.model import UNet\n","from deconoising import utils\n","from deconoising.utils import PSNR\n","from deconoising import training\n","from deconoising import prediction\n","\n","import torch\n","\n","from scipy.ndimage import gaussian_filter\n","\n","device=utils.getDevice()\n","\n","!pip install wget\n","!pip install fpdf\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil 
\n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from datetime import datetime\n","from PIL import Image \n","from PIL.TiffTags import TAGS\n","import tensorflow as tf\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['torch','numpy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[0]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(len(Training_Filelist))+' images (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a virtual batch size of '+str(virtual_batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include torch (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(len(Training_Filelist))+' images (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a virtual batch size of '+str(virtual_batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include torch (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(26, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by default.'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
patch_size{1}
batch_size{2}
virtual_batch_size{3}
number_of_steps{4}
percentage_validation{5}
initial_learning_rate{6}
std_gaussian_for_PSF{7}
pixel_size{8}
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,virtual_batch_size,number_of_steps,percentage_validation,initial_learning_rate,std_gaussian_for_PSF,pixel_size)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," # pdf.set_font('')\n"," # pdf.set_font('Arial', size = 10, style = 'B')\n"," # pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," # pdf.set_font('')\n"," # pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training Image', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample.png').shape\n"," pdf.image('/content/TrainingDataExample.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- DecoNoising: Goncharova Anna et al. \"Improving Blind Spot Denoising for Microscopy\" https://arxiv.org/abs/2008.08414. 
2020.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","\n"," #Make a pdf summary of the QC results\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'DecoNoising 2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- DecoNoising: Goncharova Anna et al. \"Improving Blind Spot Denoising for Microscopy\" https://arxiv.org/abs/2008.08414. 2020.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n"," with Image.open(TIFFpath) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n","\n"," # TIFF tags\n"," # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n"," # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n"," ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n"," width = meta_dict['ImageWidth'][0]\n"," height = meta_dict['ImageLength'][0]\n","\n"," xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n","\n"," if len(xResolution) == 1:\n"," xResolution = xResolution[0]\n"," elif len(xResolution) == 2:\n"," xResolution = xResolution[0]/xResolution[1]\n"," else:\n"," print('Image resolution not defined.')\n"," xResolution = 1\n","\n"," if ResolutionUnit == 2:\n"," # Units given are in inches\n"," pixel_size = 0.025*1e9/xResolution\n"," elif ResolutionUnit == 3:\n"," # Units given are in cm\n"," pixel_size = 0.01*1e9/xResolution\n"," else: \n"," # ResolutionUnit is therefore 1\n"," print('Resolution unit not defined. Assuming: um')\n"," pixel_size = 1e3/xResolution\n","\n"," if display:\n"," print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n"," print('Image size: '+str(width)+'x'+str(height))\n"," \n"," return (pixel_size, width, height)\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **2. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. 
To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 100**\n","\n","**`patch_size`:** Deconoising divides the image into patches for training. Input the size of the patches (length of a side). The value should be between 64 and the dimensions of the image and divisible by 8. **Default value: 80**\n"," \n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. **Default value: 4**\n","\n","**`virtual_batch_size:`** The number of batches that are processed before a gradient step is performed. **Default value: 20**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. 
**Default value: Number of images / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 20**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.001**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","#compatibility to easily change the name of the parameters\n","training_images = Training_source \n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = model_path+'/'+model_name+'/'\n","\n","\n","#@markdown Number of epochs:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels)\n","patch_size = 80#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"number\"}\n","virtual_batch_size = 20 #@param {type:\"number\"}\n","number_of_steps = 12#@param {type:\"number\"}\n","percentage_validation = 20#@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 4\n"," percentage_validation = 20\n"," initial_learning_rate = 0.001\n"," virtual_batch_size = 20\n"," \n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n","\n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check the image dimensions\n","\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print('Loaded images (width, length) =', x.shape)\n","\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","#Here we split the images between training and validation\n","\n","# Here we count the number of files in the training target folder\n","Noisy_Filelist = os.listdir(Training_source)\n","Noisy_number_files = len(Noisy_Filelist)\n","\n","# Here we count the number of file to use for validation\n","Noisy_for_validation = int((Noisy_number_files)/percentage_validation)\n","\n","if Noisy_for_validation == 0:\n"," Noisy_for_validation = 1\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","\n","#Move files into the temporary source and target directories:\n","\n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","#Here we move images to be used for validation\n","for i in range(Noisy_for_validation): \n"," shutil.move(Training_source_temp+\"/\"+list_source[i], Validation_source_temp+\"/\"+list_source[i])\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"j5aThu1DaB04"},"source":["## **3.2. PSF simulator**\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"ibpaGXPBQaNP"},"source":[" DecoNoising uses the expected point spread function of your images to denoise and improve the predictions. Here we simulate a point spread function using a Gaussian function. \n","\n","**`Lambda`:** This corresponds to the wavelength **in nanoometers** of the fluorescence emission in the experiment.\n","\n","**`NA`:** This corresponds to the numerical aperture of the microscope objective used in the experiment.\n","\n","**`pixel_size:`:** Indicate the pixel size of your images (**in nanometers**). This information can be found, for instance, by opening your images in Fiji. It can also sometimes be directly extracted from the metadata.\n","\n","The simulator uses the Gaussian approximation of the point spread function (PSF). The nominal value of the standard deviation of the corresponding 2D Gaussian function here is evaluated using `0.21 x Lambda / NA`, described by [Zhang *et al.*, Applied Optics 2007](https://doi.org/10.1364/AO.46.001819)). This equation however likely represents an optimistic estimation of the size of the PSF and therefore should be taken as an under-estimation of the real experimental PSF. 
Users may consider increasing Lambda or decreasing NA slightly to take into account potential deviation from this ideal case.\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"u2ag5fe_Ine2","cellView":"form"},"source":["\n","#@markdown ###Run this cell to simulate your PSF: \n","\n","Lambda = 500 #@param {type:\"number\"}\n","NA = 1.2 #@param {type:\"number\"}\n","\n","Load_pixel_size_from_metadata = False#@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","pixel_size = 250#@param {type:\"number\"}\n","\n","sigma_nm = 0.21*Lambda/NA\n","print(\"Gaussian PSF sigma: \"+str(sigma_nm)+\" um\")\n","std_gaussian_for_PSF = sigma_nm/pixel_size\n","print(\"Gaussian PSF sigma: \"+str(std_gaussian_for_PSF)+\" pixels\")\n","\n","\n","if Load_pixel_size_from_metadata:\n"," pixel_size,_,_ = getPixelSizeTIFFmetadata(Training_source+\"/\"+random_choice, True)\n","\n","#size_of_psf = 24*int(std_gaussian_for_PSF)+1 # 12 STD wide is good enough\n","\n","size_of_psf = patch_size+1\n","\n","if size_of_psf > patch_size:\n"," size_of_psf = (patch_size+1)\n","\n","if size_of_psf > patch_size:\n"," size_of_psf = (patch_size+1)\n","\n","\n","def artificial_psf(size_of_psf = size_of_psf, std_gauss = std_gaussian_for_PSF): \n"," filt = np.zeros((size_of_psf, size_of_psf))\n"," p = (size_of_psf - 1)//2 # integer division\n"," filt[p,p] = 1\n"," filt = torch.tensor(gaussian_filter(filt,std_gaussian_for_PSF).reshape(1,1,size_of_psf,size_of_psf).astype(np.float32))\n"," filt = filt/torch.sum(filt)\n"," return filt\n","\n","psf_tensor = artificial_psf(std_gauss = std_gaussian_for_PSF)\n","# plt.imshow(psf_tensor[0, 0, 20:60, 20:60], cmap = 'gray', interpolation= 'none')\n","plt.imshow(psf_tensor[0, 0, :, :], cmap = 'gray', interpolation= 'none')\n","# plt.axis('off')\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.3. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"cellView":"form","id":"htqjkJWt5J_8"},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU"},"source":["\n","## **3.4. Using weights from a pre-trained model as initial weights** NOT implemented \n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deconoising 2D model**. 
\n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","pretrained_model_name = os.path.basename(pretrained_model_path)\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, Weights_choice+\"_\"+pretrained_model_name+\".net\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," 
print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","os.makedirs(full_model_path)\n","\n","\n","#Data preparation Load all the images into a single numpy array\n","\n","Training_Filelist = os.listdir(Training_source_temp)\n","\n","Nb_of_training_images = len(Training_Filelist)\n","\n","Validation_Filelist = os.listdir(Validation_source_temp)\n","\n","Nb_of_validation_images = len(Validation_Filelist)\n","\n","print(str(Nb_of_training_images)+\" training images\")\n","print(str(Nb_of_validation_images)+\" validation images\")\n","\n","my_train_data=np.array([np.array(imread(Training_source_temp+\"/\"+fname)) for fname in Training_Filelist]).astype(np.float32)\n","my_val_data=np.array([np.array(imread(Validation_source_temp+\"/\"+fname)) for fname in Validation_Filelist]).astype(np.float32)\n","\n","#Subtract the mean value of the background. It can be measured in Fiji, for example. 
This is important for the positivity constraint, which requires the background to be at 0.\n","#my_train_data = my_train_data - Average_background\n","#my_val_data = my_val_data - Average_background\n","\n","\n","#automated substraction of the lowest value my_train_data\n","n_image_my_train_data = my_train_data.shape[0]\n","for n in range(n_image_my_train_data):\n"," my_train_data[n] = my_train_data[n] - np.amin(my_train_data[n])\n","\n","#automated substraction of the lowest value in my_val_data\n","\n","n_image_my_val_data = my_val_data.shape[0]\n","for n in range(n_image_my_val_data):\n"," my_val_data[n] = my_val_data[n] - np.amin(my_val_data[n])\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","\n","if (Use_Default_Advanced_Parameters) or (number_of_steps == 0): \n"," #number_of_steps = int(Image_X*Image_Y/(patch_size*patch_size)*(int(len(Training_Filelist)/batch_size)+1))\n"," number_of_steps = 20\n","\n"," print(str(number_of_steps)+\" steps per EPOCH\")\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n"," \n","#Prepare the model\n","\n","device=utils.getDevice()\n","\n","# The network requires only a single output unit per pixel\n","\n","if Use_pretrained_model:\n"," net = torch.load(h5_file_path)\n","\n","if not Use_pretrained_model:\n"," net = UNet(1, depth=3)\n","\n","net.psf=psf_tensor.to(device)\n","\n","#Export summary of training parameters as pdf\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","print(\"Setup done.\")\n","\n","\n","# creates a plot and shows one training patch and one validation patch.\n","\n","norm = simple_norm(my_train_data[0], percent = 99)\n","\n","plt.figure(figsize=(10,5))\n","plt.subplot(121)\n","plt.imshow(my_train_data[0], norm=norm, cmap='magma')\n","plt.axis('off')\n","plt.title('Training image');\n","plt.subplot(122)\n","plt.imshow(my_val_data[0], norm=norm, cmap='magma')\n","plt.axis('off')\n","plt.title('Validation image');\n","\n","plt.show()\n","\n","#pdf_export(pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this \n","point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","\n","\n","\n","# Start training.\n","trainHist, valHist = training.trainNetwork(net = net, trainData = my_train_data, valData = my_val_data,\n"," postfix = model_name, directory = full_model_path,\n"," device = device, augment=Use_Data_augmentation, numOfEpochs = number_of_epochs, stepsPerEpoch = number_of_steps, \n"," virtualBatchSize = virtual_batch_size, batchSize = batch_size, patchSize=patch_size, learningRate = initial_learning_rate,\n"," psf = psf_tensor.to(device),positivity_constraint = 1)\n","\n","\n","print(\"Training done.\")\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","#lossData = pd.DataFrame(history.valHist) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss'])\n"," for i in range(len(valHist)):\n"," writer.writerow([trainHist[i], valHist[i]])\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","pdf_export(trained = True, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"zazOZ3wDx0zQ"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. 
The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","id":"vMzSP50kMv5p"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. 
Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"nAs4Wni7VYbq"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Activate the pretrained models. 
\n","net = torch.load(QC_model_path+\"/\"+QC_model_name+\"/\" + Weights_choice+\"_\" + QC_model_name + \".net\")\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," deconvolvedResult, denoisedResult = prediction.tiledPredict(img, net ,ps=256, overlap=48, device=device)\n","\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, denoisedResult)\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","img_Source = 
io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","\n","norm_img_GT = simple_norm(img_GT, percent = 99)\n","norm_img_Source = simple_norm(img_Source, percent = 99)\n","norm_img_Prediction = simple_norm(img_Prediction, percent = 99)\n","\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","plt.imshow(img_GT, norm=norm_img_GT, cmap='magma')\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","plt.imshow(img_Source, norm=norm_img_Source, cmap='magma')\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","plt.imshow(img_Prediction, norm=norm_img_Prediction, cmap='magma')\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks"]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","cellView":"form"},"source":["Single_Images = 1\n","Stacks = 2\n","\n","from deconoising import prediction\n","\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Would you like to save the deconvolved images?\n","Save_deconvolved_images = True #@param {type:\"boolean\"}\n","\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. 
\n","net = torch.load(full_Prediction_model_path+\"/\" + Weights_choice+\"_\" + Prediction_model_name + \".net\")\n","\n"," # r=root, d=directories, f = files\n","for r, d, f in os.walk(Data_folder):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n","\n","# Loop through the files\n"," for r, d, f in os.walk(Data_folder):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," \n"," # We are using tiling to fit the image into memory\n"," # If you get an error try a smaller patch size (ps)\n"," # Here we are predicting the deconvolved and denoised image\n"," deconvolvedResult, denoisedResult = prediction.tiledPredict(input_train, net ,ps=256, overlap=48, device=device)\n","\n"," if Save_deconvolved_images:\n"," io.imsave(Result_folder+'/deconvolved_'+base_filename,deconvolvedResult)\n"," \n"," io.imsave(Result_folder+'/denoised_'+base_filename,denoisedResult)\n","\n"," print(\"Images saved into folder:\", Result_folder)\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," for r, d, f in os.walk(Data_folder):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," print(\"Denoising: \"+str(base_filename))\n"," timelapse = imread(os.path.join(r, file))\n"," n_timepoint = timelapse.shape[0]\n"," deconvolvedResult_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," denoisedResult_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n","\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," deconvolvedResult_stack[t], denoisedResult_stack[t] = prediction.tiledPredict(img_t, net ,ps=256, overlap=48, device=device)\n"," \n"," if Save_deconvolved_images:\n"," deconvolvedResult_stack_32 = img_as_float32(deconvolvedResult_stack, force_copy=False)\n"," imsave(Result_folder+'/deconvolved_'+base_filename, deconvolvedResult_stack_32) \n","\n"," denoisedResult_stack_32 = img_as_float32(denoisedResult_stack, force_copy=False)\n"," imsave(Result_folder+'/denoised_'+base_filename, denoisedResult_stack_32)\n"," del denoisedResult_stack_32 \n"," del denoisedResult_stack\n"," del timelapse\n"," del deconvolvedResult_stack\n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"67_8rEKp8C-z"},"source":["## **6.2. 
Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"n-stU-f08Cae","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","\n","\n","random_choice = random.choice(os.listdir(Data_folder))\n","input_train = imread(Data_folder+\"/\"+random_choice)\n","\n","#os.chdir(Result_folder)\n","#deconvolvedResult = imread(Result_folder+'/denoised_'+random_choice)\n","denoisedResult = imread(Result_folder+'/denoised_'+random_choice)\n","\n","norm = simple_norm(input_train, percent = 99)\n","\n","\n","if Data_type == 1 :\n","\n"," plt.figure(figsize=(15, 15))\n"," plt.subplot(1, 2, 1)\n"," plt.title('Input image')\n"," plt.imshow(input_train, norm=norm, cmap='magma')\n"," plt.axis('off');\n","\n","\n"," plt.subplot(1, 2, 2)\n"," plt.title('Denoised output')\n"," plt.imshow(denoisedResult, norm=norm, cmap='magma')\n"," plt.axis('off');\n","\n"," plt.figure(figsize=(15, 15))\n"," plt.subplot(1, 2, 1)\n"," plt.title('Input image')\n"," plt.imshow(input_train[100:200,150:250], norm=norm, cmap='magma')\n"," plt.axis('off');\n","\n"," plt.subplot(1, 2, 2)\n"," plt.title('Denoised output')\n"," plt.imshow(denoisedResult[100:200,150:250], norm=norm, cmap='magma') \n"," plt.axis('off');\n"," \n","\n","if Data_type == 2 :\n"," \n","\n"," plt.figure(figsize=(15, 15))\n"," plt.subplot(1, 2, 1)\n"," plt.title('Input image')\n"," plt.imshow(input_train[1], cmap='magma')\n"," plt.axis('off');\n","\n"," plt.subplot(1, 2, 2)\n"," plt.title('Denoised output')\n"," plt.imshow(denoisedResult[1], cmap='magma')\n"," plt.axis('off');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"kuBWz4o1CmLA"},"source":["# **7. 
Version log**\n","---\n","**v1.13**: \n","\n","\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","\n","* This version also now includes built-in version check and the version log that you're reading now."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using DecoNoising 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/Deep-STORM_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb b/Colab_notebooks/Beta notebooks/Deep-STORM_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb index 1f37f140..402abc0b 100644 --- a/Colab_notebooks/Beta notebooks/Deep-STORM_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb +++ b/Colab_notebooks/Beta notebooks/Deep-STORM_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb @@ -1,3426 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Deep-STORM_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "FpCtYevLHfl4" - }, - "source": [ - "# **Deep-STORM (2D)**\n", - "\n", - "---\n", - "\n", - "Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. The network is trained using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network output a super-resolution image with increased pixel density (typically upsampling factor of 8 in each dimension).\n", - "\n", - "Deep-STORM has **two key advantages**:\n", - "- SMLM reconstruction at high density of emitters\n", - "- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n", - "\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is based on the following paper: \n", - "\n", - "**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. 
Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n", - "\n", - "And source code found in: /~https://github.com/EliasNehme/Deep-STORM\n", - "\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wyzTn3IcHq6Y" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use our notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bEy4EBXHHyAX" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - " Deep-STORM is able to train on simulated dataset of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. 
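As an illustration of how these simulation parameters relate to your experiment, the width of the simulated Gaussian PSF can be estimated from the emission wavelength and objective NA with the rule of thumb quoted in section 3.1.b (`sigma = 0.21 x Lambda / NA`, from Zhang *et al.*, Applied Optics 2007). A minimal sketch, using placeholder values for the wavelength and NA (replace them with your own imaging conditions):

```python
# Rough estimate of the Gaussian PSF sigma (in nm) used by the simulator,
# following the sigma = 0.21 * Lambda / NA rule of thumb (Zhang et al., 2007).
# The wavelength and NA below are placeholder example values, not defaults.
wavelength = 680   # emission wavelength in nm (example value)
NA = 1.49          # numerical aperture of the objective (example value)

sigma = 0.21 * wavelength / NA
print('Estimated PSF sigma: ' + str(round(sigma, 1)) + ' nm')
```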
Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "E04mOlG_H5Tz" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "F_tjlGzsH-Dn" - }, - "source": [ - "\n", - "## **1.1. Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "gn-LaaNNICqL", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "%tensorflow_version 1.x\n", - "\n", - "import tensorflow as tf\n", - "# if tf.__version__ != '2.2.0':\n", - "# !pip install tensorflow==2.2.0\n", - "\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime settings are correct then Google did not allocate GPU to your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi\n", - "\n", - "# from tensorflow.python.client import device_lib \n", - "# device_lib.list_local_devices()\n", - "\n", - "# print the tensorflow version\n", - "print('Tensorflow version is ' + str(tf.__version__))\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tnP7wM79IKW-" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." 
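After the mounting cell below has finished, it can be worth checking that the drive is really visible to the session before pointing the notebook at any folders. A minimal sketch (the mount point `/content/gdrive` is the one used by the cell below; `MyDrive` is the default top-level folder created by the mount):

```python
import os

# Sanity check after mounting: confirm the drive is visible and list a few entries.
# '/content/gdrive' is the mount point used in the cell below; 'MyDrive' is the
# default top-level folder name created by google.colab's drive.mount.
drive_root = '/content/gdrive/MyDrive'
if os.path.isdir(drive_root):
    print('Google Drive is mounted. First entries:')
    print(os.listdir(drive_root)[:10])
else:
    print('Google Drive is not mounted yet - run the cell below first.')
```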
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "1R-7Fo34_gOd", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "#mounts user's Google Drive to Google Colab.\n", - "\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jRnQZWSZhArJ" - }, - "source": [ - "# **2. Install Deep-STORM and dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JXtHCfYt43Xz" - }, - "source": [ - "## 2.1. Install key dependencies\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Qs2fl-BG43jz", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play to install DeepSTORM dependencies\n", - "\n", - "!pip install pydeepimagej==2.1.2\n", - "# !pip uninstall -y keras-nightly\n", - "!pip install data\n", - "!pip install fpdf\n", - "!pip install h5py==2.10\n", - "\n", - "\n", - "#Force session restart\n", - "# exit(0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hIELSNcX436q" - }, - "source": [ - "## 2.2. Restart your runtime and run all the cells again. \n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GHbIqZcxeHMJ" - }, - "source": [ - "** Skip this step if you already restarted the runtime.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yuAPmDIc5WpY" - }, - "source": [ - "## 2.3. 
Load key dependencies\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "kSrZMo3X_NhO", - "cellView": "form" - }, - "source": [ - "Notebook_version = ['1.12']\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory \n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "#@markdown ##Install Deep-STORM and dependencies\n", - "# %% Model definition + helper functions\n", - "\n", - "# Import keras modules and libraries\n", - "from tensorflow import keras\n", - "from tensorflow.keras.models import Model\n", - "from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer\n", - "from tensorflow.keras.callbacks import Callback\n", - "from tensorflow.keras import backend as K\n", - "from tensorflow.keras import optimizers, losses\n", - "\n", - "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n", - "from tensorflow.keras.callbacks import ModelCheckpoint\n", - "from tensorflow.keras.callbacks import ReduceLROnPlateau\n", - "from skimage.transform import warp\n", - "from skimage.transform import SimilarityTransform\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from scipy.signal import fftconvolve\n", - "\n", - "# Import common libraries\n", - "import tensorflow as tf\n", - "import numpy as np\n", - "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", - "import h5py\n", - "import scipy.io as sio\n", - "from os.path import abspath\n", - "from sklearn.model_selection import train_test_split\n", - "from skimage import io\n", - "import time\n", - "import os\n", - "import shutil\n", - "import csv\n", - "from PIL import Image \n", - "from PIL.TiffTags import TAGS\n", - "from scipy.ndimage import gaussian_filter\n", - "import math\n", - "from astropy.visualization import simple_norm\n", - "from sys import getsizeof\n", - "from fpdf import FPDF, HTMLMixin\n", - "from pip._internal.operations.freeze import freeze\n", - "import 
subprocess\n", - "from datetime import datetime\n", - "\n", - "\n", - "# For sliders and dropdown menu, progress bar\n", - "from ipywidgets import interact\n", - "import ipywidgets as widgets\n", - "from tqdm import tqdm\n", - "\n", - "# For Multi-threading in simulation\n", - "from numba import njit, prange\n", - "\n", - "\n", - "\n", - "# define a function that projects and rescales an image to the range [0,1]\n", - "def project_01(im):\n", - " im = np.squeeze(im)\n", - " min_val = im.min()\n", - " max_val = im.max()\n", - " return (im - min_val)/(max_val - min_val)\n", - "\n", - "# normalize image given mean and std\n", - "def normalize_im(im, dmean, dstd):\n", - " im = np.squeeze(im)\n", - " im_norm = np.zeros(im.shape,dtype=np.float32)\n", - " im_norm = (im - dmean)/dstd\n", - " return im_norm\n", - "\n", - "# Define the loss history recorder\n", - "class LossHistory(Callback):\n", - " def on_train_begin(self, logs={}):\n", - " self.losses = []\n", - "\n", - " def on_batch_end(self, batch, logs={}):\n", - " self.losses.append(logs.get('loss'))\n", - " \n", - "# Define a matlab like gaussian 2D filter\n", - "def matlab_style_gauss2D(shape=(7,7),sigma=1):\n", - " \"\"\" \n", - " 2D gaussian filter - should give the same result as:\n", - " MATLAB's fspecial('gaussian',[shape],[sigma]) \n", - " \"\"\"\n", - " m,n = [(ss-1.)/2. for ss in shape]\n", - " y,x = np.ogrid[-m:m+1,-n:n+1]\n", - " h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n", - " h.astype(dtype=K.floatx())\n", - " h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n", - " sumh = h.sum()\n", - " if sumh != 0:\n", - " h /= sumh\n", - " h = h*2.0\n", - " h = h.astype('float32')\n", - " return h\n", - "\n", - "# Expand the filter dimensions\n", - "psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n", - "gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n", - "\n", - "# Combined MSE + L1 loss\n", - "def L1L2loss(input_shape):\n", - " def bump_mse(heatmap_true, spikes_pred):\n", - "\n", - " # generate the heatmap corresponding to the predicted spikes\n", - " heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n", - "\n", - " # heatmaps MSE\n", - " loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n", - "\n", - " # l1 on the predicted spikes\n", - " loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n", - " return loss_heatmaps + loss_spikes\n", - " return bump_mse\n", - "\n", - "# Define the concatenated conv2, batch normalization, and relu block\n", - "def conv_bn_relu(nb_filter, rk, ck, name):\n", - " def f(input):\n", - " conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n", - " padding=\"same\", use_bias=False,\\\n", - " kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n", - " conv_norm = BatchNormalization(name='BN-'+name)(conv)\n", - " conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n", - " return conv_norm_relu\n", - " return f\n", - "\n", - "# Define the model architechture\n", - "def CNN(input,names):\n", - " Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n", - " pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n", - " Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n", - " pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n", - " Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n", - " pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n", - " Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)\n", - " up5 = 
UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n", - " Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n", - " up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n", - " Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n", - " up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n", - " Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n", - " return Features7\n", - "\n", - "# Define the Model building for an arbitrary input size\n", - "def buildModel(input_dim, initial_learning_rate = 0.001):\n", - " input_ = Input (shape = (input_dim))\n", - " act_ = CNN (input_,'CNN')\n", - " density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n", - " activation=\"linear\", use_bias = False,\\\n", - " kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n", - " model = Model (inputs= input_, outputs=density_pred)\n", - " opt = optimizers.Adam(lr = initial_learning_rate)\n", - " model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n", - " return model\n", - "\n", - "\n", - "# define a function that trains a model for a given data SNR and density\n", - "def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):\n", - " \n", - " \"\"\"\n", - " This function trains a CNN model on the desired training set, given the \n", - " upsampled training images and labels generated in MATLAB.\n", - " \n", - " # Inputs\n", - " # TO UPDATE ----------\n", - "\n", - " # Outputs\n", - " function saves the weights of the trained model to a hdf5, and the \n", - " normalization factors to a mat file. These will be loaded later for testing \n", - " the model in test_model. 
\n", - " \"\"\"\n", - " \n", - " # for reproducibility\n", - " np.random.seed(123)\n", - "\n", - " X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)\n", - " print('Number of training examples: %d' % X_train.shape[0])\n", - " print('Number of validation examples: %d' % X_test.shape[0])\n", - " \n", - " # Setting type\n", - " X_train = X_train.astype('float32')\n", - " X_test = X_test.astype('float32')\n", - " y_train = y_train.astype('float32')\n", - " y_test = y_test.astype('float32')\n", - "\n", - " \n", - " #===================== Training set normalization ==========================\n", - " # normalize training images to be in the range [0,1] and calculate the \n", - " # training set mean and std\n", - " mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n", - " std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n", - " for i in range(X_train.shape[0]):\n", - " X_train[i, :, :] = project_01(X_train[i, :, :])\n", - " mean_train[i] = X_train[i, :, :].mean()\n", - " std_train[i] = X_train[i, :, :].std()\n", - "\n", - " # resulting normalized training images\n", - " mean_val_train = mean_train.mean()\n", - " std_val_train = std_train.mean()\n", - " X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n", - " for i in range(X_train.shape[0]):\n", - " X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n", - " \n", - " # patch size\n", - " psize = X_train_norm.shape[1]\n", - "\n", - " # Reshaping\n", - " X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n", - "\n", - " # ===================== Test set normalization ==========================\n", - " # normalize test images to be in the range [0,1] and calculate the test set \n", - " # mean and std\n", - " mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n", - " std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n", - " for i in range(X_test.shape[0]):\n", - " X_test[i, :, :] = project_01(X_test[i, :, :])\n", - " mean_test[i] = X_test[i, :, :].mean()\n", - " std_test[i] = X_test[i, :, :].std()\n", - "\n", - " # resulting normalized test images\n", - " mean_val_test = mean_test.mean()\n", - " std_val_test = std_test.mean()\n", - " X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n", - " for i in range(X_test.shape[0]):\n", - " X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n", - " \n", - " # Reshaping\n", - " X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n", - "\n", - " # Reshaping labels\n", - " Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n", - " Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n", - "\n", - " # Save datasets to a matfile to open later in matlab\n", - " mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test, \"upsampling_factor\": upsampling_factor, \"Normalization factor\": L2_weighting_factor}\n", - " sio.savemat(os.path.join(modelPath,\"model_metadata.mat\"), mdict)\n", - "\n", - "\n", - " # Set the dimensions ordering according to tensorflow consensous\n", - " # K.set_image_dim_ordering('tf')\n", - " K.set_image_data_format('channels_last')\n", - "\n", - " # Save the model weights after each epoch if the validation loss decreased\n", - " checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_best.hdf5\"), verbose=1,\n", - " save_best_only=True)\n", - "\n", - " # Change learning when loss reaches a plataeu\n", - " change_lr = ReduceLROnPlateau(monitor='val_loss', 
factor=0.1, patience=5, min_lr=0.00005)\n", - " \n", - " # Model building and complitation\n", - " model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)\n", - " model.summary()\n", - "\n", - " # Load pretrained model\n", - " if not pretrained_model_path:\n", - " print('Using random initial model weights.')\n", - " else:\n", - " print('Loading model weights from '+pretrained_model_path)\n", - " model.load_weights(pretrained_model_path)\n", - " \n", - " # Create an image data generator for real time data augmentation\n", - " datagen = ImageDataGenerator(\n", - " featurewise_center=False, # set input mean to 0 over the dataset\n", - " samplewise_center=False, # set each sample mean to 0\n", - " featurewise_std_normalization=False, # divide inputs by std of the dataset\n", - " samplewise_std_normalization=False, # divide each input by its std\n", - " zca_whitening=False, # apply ZCA whitening\n", - " rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n", - " width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n", - " height_shift_range=0., # randomly shift images vertically (fraction of total height)\n", - " zoom_range=0.,\n", - " shear_range=0.,\n", - " horizontal_flip=False, # randomly flip images\n", - " vertical_flip=False, # randomly flip images\n", - " fill_mode='constant',\n", - " data_format=K.image_data_format())\n", - "\n", - " # Fit the image generator on the training data\n", - " datagen.fit(X_train_norm)\n", - " \n", - " # loss history recorder\n", - " history = LossHistory()\n", - "\n", - " # Inform user training begun\n", - " print('-------------------------------')\n", - " print('Training model...')\n", - "\n", - " # Fit model on the batches generated by datagen.flow()\n", - " train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \n", - " steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \n", - " validation_data=(X_test_norm, Y_test), \n", - " callbacks=[history, checkpointer, change_lr]) \n", - "\n", - " # Inform user training ended\n", - " print('-------------------------------')\n", - " print('Training Complete!')\n", - " \n", - " # Save the last model\n", - " model.save(os.path.join(modelPath, 'weights_last.hdf5'))\n", - "\n", - " # convert the history.history dict to a pandas DataFrame: \n", - " lossData = pd.DataFrame(train_history.history) \n", - "\n", - " if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n", - " shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n", - "\n", - " os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n", - "\n", - " # The training evaluation.csv is saved (overwrites the Files if needed). 
\n", - " lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n", - " with open(lossDataCSVpath, 'w') as f:\n", - " writer = csv.writer(f)\n", - " writer.writerow(['loss','val_loss','learning rate'])\n", - " for i in range(len(train_history.history['loss'])):\n", - " writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])\n", - "\n", - " return\n", - "\n", - "\n", - "# Normalization functions from Martin Weigert used in CARE\n", - "def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " \"\"\"Percentile-based image normalization.\"\"\"\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - "def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " if dtype is not None:\n", - " x = x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - "def norm_minmse(gt, x, normalize_gt=True):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - "\n", - " \"\"\"\n", - " normalizes and affinely scales an image pair such that the MSE is minimized \n", - " \n", - " Parameters\n", - " ----------\n", - " gt: ndarray\n", - " the ground truth image \n", - " x: ndarray\n", - " the image that will be affinely scaled \n", - " normalize_gt: bool\n", - " set to True of gt image should be normalized (default)\n", - " Returns\n", - " -------\n", - " gt_scaled, x_scaled \n", - " \"\"\"\n", - " if normalize_gt:\n", - " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", - " x = x.astype(np.float32, copy=False) - np.mean(x)\n", - " #x = x - np.mean(x)\n", - " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", - " #gt = gt - np.mean(gt)\n", - " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", - " return gt, scale * x\n", - "\n", - "\n", - "# Multi-threaded Erf-based image construction\n", - "@njit(parallel=True)\n", - "def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):\n", - " w = image_size[0]\n", - " h = image_size[1]\n", - " erfImage = np.zeros((w, h))\n", - " for ij in prange(w*h):\n", - " j = int(ij/w)\n", - " i = ij - j*w\n", - " for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n", - " # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n", - " if (sigma > 0) and (photon > 0):\n", - " S = sigma*math.sqrt(2)\n", - " x = i*pixel_size - xc\n", - " y = j*pixel_size - yc\n", - " # Don't bother if the emitter is further than 4 sigma from the centre of the pixel\n", - " if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n", - " ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n", - " ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n", - " erfImage[j][i] += 
0.25*photon*ErfX*ErfY\n", - " return erfImage\n", - "\n", - "\n", - "@njit(parallel=True)\n", - "def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):\n", - " w = image_size[0]\n", - " h = image_size[1]\n", - " locImage = np.zeros((image_size[0],image_size[1]) )\n", - " n_locs = len(xc_array)\n", - "\n", - " for e in prange(n_locs):\n", - " locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1\n", - "\n", - " return locImage\n", - "\n", - "\n", - "\n", - "def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n", - " with Image.open(TIFFpath) as img:\n", - " meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n", - "\n", - "\n", - " # TIFF tags\n", - " # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n", - " # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n", - " ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n", - " width = meta_dict['ImageWidth'][0]\n", - " height = meta_dict['ImageLength'][0]\n", - "\n", - " xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n", - "\n", - " if len(xResolution) == 1:\n", - " xResolution = xResolution[0]\n", - " elif len(xResolution) == 2:\n", - " xResolution = xResolution[0]/xResolution[1]\n", - " else:\n", - " print('Image resolution not defined.')\n", - " xResolution = 1\n", - "\n", - " if ResolutionUnit == 2:\n", - " # Units given are in inches\n", - " pixel_size = 0.025*1e9/xResolution\n", - " elif ResolutionUnit == 3:\n", - " # Units given are in cm\n", - " pixel_size = 0.01*1e9/xResolution\n", - " else: \n", - " # ResolutionUnit is therefore 1\n", - " print('Resolution unit not defined. Assuming: um')\n", - " pixel_size = 1e3/xResolution\n", - "\n", - " if display:\n", - " print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n", - " print('Image size: '+str(width)+'x'+str(height))\n", - " \n", - " return (pixel_size, width, height)\n", - "\n", - "\n", - "def saveAsTIF(path, filename, array, pixel_size):\n", - " \"\"\"\n", - " Image saving using PIL to save as .tif format\n", - " # Input \n", - " path - path where it will be saved\n", - " filename - name of the file to save (no extension)\n", - " array - numpy array conatining the data at the required format\n", - " pixel_size - physical size of pixels in nanometers (identical for x and y)\n", - " \"\"\"\n", - "\n", - " # print('Data type: '+str(array.dtype))\n", - " if (array.dtype == np.uint16):\n", - " mode = 'I;16'\n", - " elif (array.dtype == np.uint32):\n", - " mode = 'I'\n", - " else:\n", - " mode = 'F'\n", - "\n", - " # Rounding the pixel size to the nearest number that divides exactly 1cm.\n", - " # Resolution needs to be a rational number --> see TIFF format\n", - " # pixel_size = 10000/(round(10000/pixel_size))\n", - "\n", - " if len(array.shape) == 2:\n", - " im = Image.fromarray(array)\n", - " im.save(os.path.join(path, filename+'.tif'),\n", - " mode = mode, \n", - " resolution_unit = 3,\n", - " resolution = 0.01*1e9/pixel_size)\n", - "\n", - "\n", - " elif len(array.shape) == 3:\n", - " imlist = []\n", - " for frame in array:\n", - " imlist.append(Image.fromarray(frame))\n", - "\n", - " imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,\n", - " append_images=imlist[1:],\n", - " mode = mode, \n", - " resolution_unit = 3,\n", - " resolution = 0.01*1e9/pixel_size)\n", - "\n", - " return\n", - "\n", - "\n", - "\n", - "\n", - "class 
Maximafinder(Layer):\n", - " def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):\n", - " super(Maximafinder, self).__init__(**kwargs)\n", - " self.thresh = tf.constant(thresh, dtype=tf.float32)\n", - " self.nhood = neighborhood_size\n", - " self.use_local_avg = use_local_avg\n", - "\n", - " def build(self, input_shape):\n", - " if self.use_local_avg is True:\n", - " self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])\n", - " self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n", - " self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n", - "\n", - " def call(self, inputs):\n", - "\n", - " # local maxima positions\n", - " max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)\n", - " cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)\n", - " indices = tf.where(cond)\n", - " bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]\n", - " confidence = tf.gather_nd(inputs, indices)\n", - "\n", - " # local CoG estimator\n", - " if self.use_local_avg:\n", - " x_image = K.conv2d(inputs, self.kernel_x, padding='same')\n", - " y_image = K.conv2d(inputs, self.kernel_y, padding='same')\n", - " sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')\n", - " confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)\n", - " x_local = tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))\n", - " y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))\n", - " xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)\n", - " yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)\n", - " else:\n", - " xind = tf.cast(xind, dtype=tf.float32)\n", - " yind = tf.cast(yind, dtype=tf.float32)\n", - " \n", - " return bind, xind, yind, confidence\n", - "\n", - " def get_config(self):\n", - "\n", - " # Implement get_config to enable serialization. This is optional.\n", - " base_config = super(Maximafinder, self).get_config()\n", - " config = {}\n", - " return dict(list(base_config.items()) + list(config.items()))\n", - "\n", - "\n", - "\n", - "# ------------------------------- Prediction with postprocessing function-------------------------------\n", - "def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):\n", - " \"\"\"\n", - " This function tests a trained model on the desired test set, given the \n", - " tiff stack of test images, learned weights, and normalization factors.\n", - " \n", - " # Inputs\n", - " dataPath - the path to the folder containing the tiff stack(s) to run prediction on \n", - " filename - the name of the file to process\n", - " modelPath - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model\n", - " savePath - the path to the folder where to save the prediction\n", - " batch_size. 
- the number of frames to predict on for each iteration\n", - " thresh - threshoold percentage from the maximum of the gaussian scaling\n", - " neighborhood_size - the size of the neighborhood for local maxima finding\n", - " use_local_average - Boolean whether to perform local averaging or not\n", - " \"\"\"\n", - " \n", - " # load mean and std\n", - " matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))\n", - " test_mean = np.array(matfile['mean_test'])\n", - " test_std = np.array(matfile['std_test']) \n", - " upsampling_factor = np.array(matfile['upsampling_factor'])\n", - " upsampling_factor = upsampling_factor.item() # convert to scalar\n", - " L2_weighting_factor = np.array(matfile['Normalization factor'])\n", - " L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n", - "\n", - " # Read in the raw file\n", - " Images = io.imread(os.path.join(dataPath, filename))\n", - " if pixel_size == None:\n", - " pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)\n", - " pixel_size_hr = pixel_size/upsampling_factor\n", - "\n", - " # get dataset dimensions\n", - " (nFrames, M, N) = Images.shape\n", - " print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')\n", - "\n", - " # Build the model for a bigger image\n", - " model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))\n", - "\n", - " # Load the trained weights\n", - " model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))\n", - "\n", - " # add a post-processing module\n", - " max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)\n", - "\n", - " # Initialise the results: lists will be used to collect all the localizations\n", - " frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []\n", - "\n", - " # Initialise the results\n", - " Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n", - " Widefield = np.zeros((M, N), dtype=np.float32)\n", - "\n", - " # run model in batches\n", - " n_batches = math.ceil(nFrames/batch_size)\n", - " for b in tqdm(range(n_batches)):\n", - "\n", - " nF = min(batch_size, nFrames - b*batch_size)\n", - " Images_norm = np.zeros((nF, M, N),dtype=np.float32)\n", - " Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n", - "\n", - " # Upsampling using a simple nearest neighbor interp and calculating - MULTI-THREAD this?\n", - " for f in range(nF):\n", - " Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])\n", - " Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)\n", - " Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))\n", - " Widefield += Images[b*batch_size+f,:,:]\n", - "\n", - " # Reshaping\n", - " Images_upsampled = np.expand_dims(Images_upsampled,axis=3)\n", - "\n", - " # Run prediction and local amxima finding\n", - " predicted_density = model.predict_on_batch(Images_upsampled)\n", - " predicted_density[predicted_density < 0] = 0\n", - " Prediction += predicted_density.sum(axis = 3).sum(axis = 0)\n", - "\n", - " bind, xind, yind, confidence = max_layer(predicted_density)\n", - " bind = bind.eval(session=tf.compat.v1.Session())\n", - " xind = xind.eval(session=tf.compat.v1.Session())\n", - " yind = yind.eval(session=tf.compat.v1.Session())\n", - " confidence = confidence.eval(session=tf.compat.v1.Session())\n", - "\n", - " # normalizing the confidence by the L2_weighting_factor\n", - " 
confidence /= L2_weighting_factor \n", - "\n", - " # turn indices to nms and append to the results\n", - " xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n", - " # frmind = (bind.numpy() + b*batch_size + 1).tolist()\n", - " # xind = xind.numpy().tolist()\n", - " # yind = yind.numpy().tolist()\n", - " # confidence = confidence.numpy().tolist()\n", - " frmind = (bind + b*batch_size + 1).tolist()\n", - " xind = xind.tolist()\n", - " yind = yind.tolist()\n", - " confidence = confidence.tolist()\n", - " \n", - " frame_number_list += frmind\n", - " x_nm_list += xind\n", - " y_nm_list += yind\n", - " confidence_au_list += confidence\n", - "\n", - " # Open and create the csv file that will contain all the localizations\n", - " if use_local_avg:\n", - " ext = '_avg'\n", - " else:\n", - " ext = '_max'\n", - " with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - " writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n", - " locs = list(zip(frame_number_list, x_nm_list, y_nm_list, confidence_au_list))\n", - " writer.writerows(locs)\n", - "\n", - " # Save the prediction and widefield image\n", - " Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))\n", - " Widefield = np.float32(Widefield)\n", - "\n", - " # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)\n", - " # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)\n", - "\n", - " saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)\n", - " saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)\n", - "\n", - "\n", - " return\n", - "\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - " NORMAL = '\\033[0m' # white (normal)\n", - "\n", - "\n", - "\n", - "def list_files(directory, extension):\n", - " return (f for f in os.listdir(directory) if f.endswith('.' 
+ extension))\n", - "\n", - "\n", - "# @njit(parallel=True)\n", - "def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):\n", - " xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')\n", - " centralPatch = XC[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]\n", - "\n", - " if (method == 'MAX'):\n", - " x0 = xMaxInd\n", - " y0 = yMaxInd\n", - "\n", - " elif (method == 'CoM'):\n", - " x0 = 0\n", - " y0 = 0\n", - " S = 0\n", - " for xy in range(patch_size*patch_size):\n", - " y = math.floor(xy/patch_size)\n", - " x = xy - y*patch_size\n", - " x0 += x*array[x,y]\n", - " y0 += y*array[x,y]\n", - " S = array[x,y]\n", - " \n", - " x0 = x0/S - patch_size/2 + xMaxInd\n", - " y0 = y0/S - patch_size/2 + yMaxInd\n", - " \n", - " elif (method == 'Radiality'):\n", - " # Not implemented yet\n", - " x0 = xMaxInd\n", - " y0 = yMaxInd\n", - " \n", - " return (x0, y0)\n", - "\n", - "\n", - "@njit(parallel=True)\n", - "def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):\n", - " n_locs = xc_array.shape[0]\n", - " xc_array_Corr = np.empty(n_locs)\n", - " yc_array_Corr = np.empty(n_locs)\n", - " \n", - " for loc in prange(n_locs):\n", - " xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]\n", - " yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]\n", - "\n", - " return (xc_array_Corr, yc_array_Corr)\n", - "\n", - "\n", - "print('--------------------------------')\n", - "print('DeepSTORM installation complete.')\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "if Notebook_version == list(Latest_notebook_version.columns):\n", - " print(\"This notebook is up-to-date.\")\n", - "\n", - "if not Notebook_version == list(Latest_notebook_version.columns):\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "def pdf_export(trained = False, raw_data = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'Deep-STORM'\n", - " #model_name = 'little_CARE_test'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hours)+ \"hour(s) \"+str(minutes)+\"min(s) \"+str(round(seconds))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and method:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - " if raw_data == True:\n", - " shape = (M,N)\n", - " else:\n", - " shape = (int(FOV_size/pixel_size),int(FOV_size/pixel_size))\n", - " #dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+' GPU.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. The models was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(180, 5, txt = text, align='L')\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if raw_data==False:\n", - " simul_text = 'The training dataset was created in the notebook using the following simulation settings:'\n", - " pdf.cell(200, 5, txt=simul_text, align='L')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<tr><th>Setting</th><th>Simulated Value</th></tr>
<tr><td>FOV_size</td><td>{0}</td></tr>
<tr><td>pixel_size</td><td>{1}</td></tr>
<tr><td>ADC_per_photon_conversion</td><td>{2}</td></tr>
<tr><td>ReadOutNoise_ADC</td><td>{3}</td></tr>
<tr><td>ADC_offset</td><td>{4}</td></tr>
<tr><td>emitter_density</td><td>{5}</td></tr>
<tr><td>emitter_density_std</td><td>{6}</td></tr>
<tr><td>number_of_frames</td><td>{7}</td></tr>
<tr><td>sigma</td><td>{8}</td></tr>
<tr><td>sigma_std</td><td>{9}</td></tr>
<tr><td>n_photons</td><td>{10}</td></tr>
<tr><td>n_photons_std</td><td>{11}</td></tr>
\n", - " \"\"\".format(FOV_size, pixel_size, ADC_per_photon_conversion, ReadOutNoise_ADC, ADC_offset, emitter_density, emitter_density_std, number_of_frames, sigma, sigma_std, n_photons, n_photons_std)\n", - " pdf.write_html(html)\n", - " else:\n", - " simul_text = 'The training dataset was simulated using ThunderSTORM and loaded into the notebook.'\n", - " pdf.multi_cell(190, 5, txt=simul_text, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " #pdf.ln(1)\n", - " #pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(29, 5, txt= 'ImageData_path', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = ImageData_path, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'LocalizationData_path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = LocalizationData_path, align = 'L')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'pixel_size:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = str(pixel_size), align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " # if Use_Default_Advanced_Parameters:\n", - " # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used to generate patches:')\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(str(patch_size)+'x'+str(patch_size), upsampling_factor, num_patches_per_frame, min_number_of_emitters_per_patch, max_num_patches, gaussian_sigma, Automatic_normalization, L2_weighting_factor)\n", - " pdf.write_html(html)\n", - " pdf.ln(3)\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - "
<tr><th>Patch Parameter</th><th>Value</th></tr>
<tr><td>patch_size</td><td>{0}</td></tr>
<tr><td>upsampling_factor</td><td>{1}</td></tr>
<tr><td>num_patches_per_frame</td><td>{2}</td></tr>
<tr><td>min_number_of_emitters_per_patch</td><td>{3}</td></tr>
<tr><td>max_num_patches</td><td>{4}</td></tr>
<tr><td>gaussian_sigma</td><td>{5}</td></tr>
<tr><td>Automatic_normalization</td><td>{6}</td></tr>
<tr><td>L2_weighting_factor</td><td>{7}</td></tr>
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<tr><th>Training Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>batch_size</td><td>{1}</td></tr>
<tr><td>number_of_steps</td><td>{2}</td></tr>
<tr><td>percentage_validation</td><td>{3}</td></tr>
<tr><td>initial_learning_rate</td><td>{4}</td></tr>
\n", - " \"\"\".format(number_of_epochs,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " # pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - "\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training Images', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_DeepSTORM2D.png').shape\n", - " pdf.image('/content/TrainingDataExample_DeepSTORM2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " # if Use_Data_augmentation:\n", - " # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n", - " print('------------------------------')\n", - " print('PDF report exported in '+model_path+'/'+model_name+'/')\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'Deep-STORM'\n", - " #model_name = os.path.basename(full_QC_model_path)\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+os.path.basename(QC_model_path)+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n", - " pdf.ln(1)\n", - " if os.path.exists(savePath+'/lossCurvePlots.png'):\n", - " exp_size = io.imread(savePath+'/lossCurvePlots.png').shape\n", - " pdf.image(savePath+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " 
pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(savePath+'/QC_example_data.png').shape\n", - " pdf.image(savePath+'/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(savePath+'/'+os.path.basename(QC_model_path)+'_QC_metrics.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " mSSIM_PvsGT = header[1]\n", - " mSSIM_SvsGT = header[2]\n", - " NRMSE_PvsGT = header[3]\n", - " NRMSE_SvsGT = header[4]\n", - " PSNR_PvsGT = header[5]\n", - " PSNR_SvsGT = header[6]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n", - " html = html+header\n", - " for row in metrics:\n", - " image = row[0]\n", - " mSSIM_PvsGT = row[1]\n", - " mSSIM_SvsGT = row[2]\n", - " NRMSE_PvsGT = row[3]\n", - " NRMSE_SvsGT = row[4]\n", - " PSNR_PvsGT = row[5]\n", - " PSNR_SvsGT = row[6]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
\"\"\"\n", - " \n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n", - "\n", - "\n", - " print('------------------------------')\n", - " print('QC PDF report exported as '+savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vu8f5NGJkJos" - }, - "source": [ - "\n", - "# **3. Generate patches for training**\n", - "---\n", - "\n", - "For Deep-STORM the training data can be obtained in two ways:\n", - "* Simulated using ThunderSTORM or other simulation tool and loaded here (**using Section 3.1.a**)\n", - "* Directly simulated in this notebook (**using Section 3.1.b**)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WSV8xnlynp0l" - }, - "source": [ - "## **3.1.a Load training data**\n", - "---\n", - "\n", - "Here you can load your simulated data along with its corresponding localization file.\n", - "* The `pixel_size` is defined in nanometer (nm). 
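Because both the localization coordinates and `pixel_size` are expressed in nanometers, a localization is mapped back onto the camera pixel grid simply by dividing by `pixel_size` (this is what the image-rendering helpers loaded in section 2.3 do internally). A minimal sketch, assuming ThunderSTORM-style column names `x [nm]` and `y [nm]` and a hypothetical file name; adjust both to match your own localization file:

```python
import pandas as pd

# Minimal sketch: convert nanometer coordinates to camera pixel indices.
# 'my_localizations.csv' is a hypothetical file name and 'x [nm]' / 'y [nm]'
# are assumed ThunderSTORM-style column names - adjust them to your own file.
pixel_size = 100  # camera pixel size in nm

locs = pd.read_csv('my_localizations.csv', index_col=0)
x_pix = locs['x [nm]'] / pixel_size
y_pix = locs['y [nm]'] / pixel_size
print('First localization falls in pixel (x, y) = ('
      + str(int(x_pix.iloc[0])) + ', ' + str(int(y_pix.iloc[0])) + ')')
```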
" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "CT6SNcfNg6j0", - "cellView": "form" - }, - "source": [ - "#@markdown ##Load raw data\n", - "\n", - "load_raw_data = True\n", - "\n", - "# Get user input\n", - "ImageData_path = \"\" #@param {type:\"string\"}\n", - "LocalizationData_path = \"\" #@param {type: \"string\"}\n", - "#@markdown Get pixel size from file?\n", - "get_pixel_size_from_file = True #@param {type:\"boolean\"}\n", - "#@markdown Otherwise, use this value:\n", - "pixel_size = 100 #@param {type:\"number\"}\n", - "\n", - "if get_pixel_size_from_file:\n", - " pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)\n", - "\n", - "# load the tiff data\n", - "Images = io.imread(ImageData_path)\n", - "# get dataset dimensions\n", - "if len(Images.shape) == 3:\n", - " (number_of_frames, M, N) = Images.shape\n", - "elif len(Images.shape) == 2:\n", - " (M, N) = Images.shape\n", - " number_of_frames = 1\n", - "print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n", - "\n", - "# Interactive display of the stack\n", - "def scroll_in_time(frame):\n", - " f=plt.figure(figsize=(6,6))\n", - " plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')\n", - " plt.title('Training source at frame = ' + str(frame))\n", - " plt.axis('off');\n", - "\n", - "if number_of_frames > 1:\n", - " interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n", - "else:\n", - " f=plt.figure(figsize=(6,6))\n", - " plt.imshow(Images, interpolation='nearest', cmap = 'gray')\n", - " plt.title('Training source')\n", - " plt.axis('off');\n", - "\n", - "# Load the localization file and display the first\n", - "LocData = pd.read_csv(LocalizationData_path, index_col=0)\n", - "LocData.tail()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K9xE5GeYiks9" - }, - "source": [ - "## **3.1.b Simulate training data**\n", - "---\n", - "This simulation tool allows you to generate SMLM data of randomly distrubuted emitters in a field-of-view. \n", - "The assumptions are as follows:\n", - "\n", - "* Gaussian Point Spread Function (PSF) with standard deviation defined by `Sigma`. The nominal value of `sigma` can be evaluated using `sigma = 0.21 x Lambda / NA`. (from [Zhang *et al.*, Applied Optics 2007](https://doi.org/10.1364/AO.46.001819))\n", - "* Each emitter will emit `n_photons` per frame, and generate their equivalent Poisson noise.\n", - "* The camera will contribute Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC` in ADC\n", - "* The `emitter_density` is defined as the number of emitters / um^2 on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`. The latter parameter represents the standard deviation of the normal distribution that the density is drawn from for each individual frame. `emitter_density` **is defined in number of emitters / um^2**.\n", - "* The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.\n", - "\n", - "Important note:\n", - "- All dimensions are in nanometer (e.g. 
`FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "sQyLXpEhitsg", - "cellView": "form" - }, - "source": [ - "load_raw_data = False\n", - "\n", - "# ---------------------------- User input ----------------------------\n", - "#@markdown Run the simulation\n", - "#@markdown --- \n", - "#@markdown Camera settings: \n", - "FOV_size = 6400#@param {type:\"number\"}\n", - "pixel_size = 100#@param {type:\"number\"}\n", - "ADC_per_photon_conversion = 1 #@param {type:\"number\"}\n", - "ReadOutNoise_ADC = 4.5#@param {type:\"number\"}\n", - "ADC_offset = 50#@param {type:\"number\"}\n", - "\n", - "#@markdown Acquisition settings: \n", - "emitter_density = 6#@param {type:\"number\"}\n", - "emitter_density_std = 0#@param {type:\"number\"}\n", - "\n", - "number_of_frames = 20#@param {type:\"integer\"}\n", - "\n", - "sigma = 110 #@param {type:\"number\"}\n", - "sigma_std = 5 #@param {type:\"number\"}\n", - "# NA = 1.1 #@param {type:\"number\"}\n", - "# wavelength = 800#@param {type:\"number\"}\n", - "# wavelength_std = 150#@param {type:\"number\"}\n", - "n_photons = 2250#@param {type:\"number\"}\n", - "n_photons_std = 250#@param {type:\"number\"}\n", - "\n", - "\n", - "# ---------------------------- Variable initialisation ----------------------------\n", - "# Start the clock to measure how long it takes\n", - "start = time.time()\n", - "\n", - "print('-----------------------------------------------------------')\n", - "n_molecules = emitter_density*FOV_size*FOV_size/10**6\n", - "n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6\n", - "print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))\n", - "\n", - "# sigma = 0.21*wavelength/NA\n", - "# sigma_std = 0.21*wavelength_std/NA\n", - "# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n", - "\n", - "M = N = round(FOV_size/pixel_size)\n", - "FOV_size = M*pixel_size\n", - "print('Final image size: '+str(M)+'x'+str(M)+' ('+str(round(FOV_size/1000, 3))+'um x'+str(round(FOV_size/1000,3))+' um)')\n", - "\n", - "np.random.seed(1)\n", - "display_upsampling = 8 # used to display the loc map here\n", - "NoiseFreeImages = np.zeros((number_of_frames, M, M))\n", - "locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n", - "\n", - "frames = []\n", - "all_xloc = []\n", - "all_yloc = []\n", - "all_photons = []\n", - "all_sigmas = []\n", - "\n", - "# ---------------------------- Main simulation loop ----------------------------\n", - "print('-----------------------------------------------------------')\n", - "for f in tqdm(range(number_of_frames)):\n", - " \n", - " # Define the coordinates of emitters by randomly distributing them across the FOV\n", - " n_mol = int(max(round(np.random.normal(n_molecules, n_molecules_std, size=1)[0]), 0))\n", - " x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n", - " y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n", - " photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)\n", - " sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)\n", - " # x_c = np.linspace(0,3000,5)\n", - " # y_c = np.linspace(0,3000,5)\n", - "\n", - " all_xloc += x_c.tolist()\n", - " all_yloc += y_c.tolist()\n", - " frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n", - " all_photons += photon_array.tolist()\n", - " all_sigmas += sigma_array.tolist()\n", - "\n", - " locImage[f] = 
FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)\n", - "\n", - " # # Get the approximated locations according to the grid pixel size\n", - " # Chr_emitters = [int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n", - " # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n", - "\n", - " # # Build Localization image\n", - " # for (r,c) in zip(Rhr_emitters, Chr_emitters):\n", - " # locImage[f][r][c] += 1\n", - "\n", - " NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, image_size = (M,M), pixel_size = pixel_size)\n", - "\n", - "\n", - "# ---------------------------- Create DataFrame fof localization file ----------------------------\n", - "# Table with localization info as dataframe output\n", - "LocData = pd.DataFrame()\n", - "LocData[\"frame\"] = frames\n", - "LocData[\"x [nm]\"] = all_xloc\n", - "LocData[\"y [nm]\"] = all_yloc\n", - "LocData[\"Photon #\"] = all_photons\n", - "LocData[\"Sigma [nm]\"] = all_sigmas\n", - "LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n", - "\n", - "\n", - "# ---------------------------- Estimation of SNR ----------------------------\n", - "n_frames_for_SNR = 100\n", - "M_SNR = 10\n", - "x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n", - "y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n", - "photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)\n", - "sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)\n", - "\n", - "SNR = np.zeros(n_frames_for_SNR)\n", - "for i in range(n_frames_for_SNR):\n", - " SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([x_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)\n", - " Signal_photon = np.max(SingleEmitterImage)\n", - " Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)\n", - " SNR[i] = Signal_photon/Noise_photon\n", - "\n", - "print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))\n", - "# ---------------------------- ----------------------------\n", - "\n", - "\n", - "# Table with info\n", - "simParameters = pd.DataFrame()\n", - "simParameters[\"FOV size (nm)\"] = [FOV_size]\n", - "simParameters[\"Pixel size (nm)\"] = [pixel_size]\n", - "simParameters[\"ADC/photon\"] = [ADC_per_photon_conversion]\n", - "simParameters[\"Read-out noise (ADC)\"] = [ReadOutNoise_ADC]\n", - "simParameters[\"Constant offset (ADC)\"] = [ADC_offset]\n", - "\n", - "simParameters[\"Emitter density (emitters/um^2)\"] = [emitter_density]\n", - "simParameters[\"STD of emitter density (emitters/um^2)\"] = [emitter_density_std]\n", - "simParameters[\"Number of frames\"] = [number_of_frames]\n", - "# simParameters[\"NA\"] = [NA]\n", - "# simParameters[\"Wavelength (nm)\"] = [wavelength]\n", - "# simParameters[\"STD of wavelength (nm)\"] = [wavelength_std]\n", - "simParameters[\"Sigma (nm))\"] = [sigma]\n", - "simParameters[\"STD of Sigma (nm))\"] = [sigma_std]\n", - "simParameters[\"Number of photons\"] = [n_photons]\n", - "simParameters[\"STD of number of photons\"] = [n_photons_std]\n", - "simParameters[\"SNR\"] = [np.mean(SNR)]\n", - "simParameters[\"STD of SNR\"] = [np.std(SNR)]\n", - "\n", - "\n", - "# ---------------------------- Finish 
simulation ----------------------------\n", - "# Calculating the noisy image\n", - "Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n", - "Images[Images <= 0] = 0\n", - "\n", - "# Convert to 16-bit or 32-bits integers\n", - "if Images.max() < (2**16-1):\n", - " Images = Images.astype(np.uint16)\n", - "else:\n", - " Images = Images.astype(np.uint32)\n", - "\n", - "\n", - "# ---------------------------- Display ----------------------------\n", - "# Displaying the time elapsed for simulation\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n", - "\n", - "\n", - "# Interactively display the results using Widgets\n", - "def scroll_in_time(frame):\n", - " f = plt.figure(figsize=(18,6))\n", - " plt.subplot(1,3,1)\n", - " plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n", - " plt.title('Localization image')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,3,2)\n", - " plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')\n", - " plt.title('Noise-free simulation')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,3,3)\n", - " plt.imshow(Images[frame-1], interpolation='nearest', cmap='gray')\n", - " plt.title('Noisy simulation')\n", - " plt.axis('off');\n", - "\n", - "interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n", - "\n", - "# Display the head of the dataframe with localizations\n", - "LocData.tail()\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "Pz7RfSuoeJeq", - "cellView": "form" - }, - "source": [ - "# @markdown ---\n", - "# @markdown #Play this cell to save the simulated stack\n", - "# @markdown ####Please select a path to the folder where to save the simulated data. It is not necessary to save the data to run the training, but keeping the simulated for your own record can be useful to check its validity.\n", - "Save_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if not os.path.exists(Save_path):\n", - " os.makedirs(Save_path)\n", - " print('Folder created.')\n", - "else:\n", - " print('Training data already exists in folder: Data overwritten.')\n", - "\n", - "saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)\n", - "# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)\n", - "LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))\n", - "simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))\n", - "print('Training dataset saved.')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K_8e3kE-JhVY" - }, - "source": [ - "## **3.2. Generate training patches**\n", - "---\n", - "\n", - "Training patches need to be created from the training data generated above. \n", - "* The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. **DEFAULT: 26**\n", - "* The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). This is used to generate the super-resolved patches as target dataset. 
Using an `upsampling_factor` of 16 will require more memory and it may be necessary to decrease the `patch_size` to 16, for example. **DEFAULT: 8**\n", - "* The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**\n", - "* The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**\n", - "* The `max_num_patches` defines the maximum number of patches to generate. Fewer may be generated depending on how many patches are rejected and how many frames are available. **DEFAULT: 10000**\n", - "* The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**\n", - "* The `L2_weighting_factor` is a normalization factor used in the loss function. It helps balance the loss from the L2 norm. When using higher densities, this factor should be decreased and vice-versa. This factor can be automatically calculated using an empirical formula. **DEFAULT: 100**\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "AsNx5KzcFNvC", - "cellView": "form" - }, - "source": [ - "#@markdown ## **Provide patch parameters**\n", - "\n", - "\n", - "# -------------------- User input --------------------\n", - "patch_size = 26 #@param {type:\"integer\"}\n", - "upsampling_factor = 8 #@param [\"4\", \"8\", \"16\"] {type:\"raw\"}\n", - "num_patches_per_frame = 500#@param {type:\"integer\"}\n", - "min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n", - "max_num_patches = 10000#@param {type:\"integer\"}\n", - "gaussian_sigma = 1#@param {type:\"integer\"}\n", - "\n", - "#@markdown Estimate the optimal normalization factor automatically?\n", - "Automatic_normalization = True #@param {type:\"boolean\"}\n", - "#@markdown Otherwise, it will use the following value:\n", - "L2_weighting_factor = 100 #@param {type:\"number\"}\n", - "\n", - "\n", - "# -------------------- Prepare variables --------------------\n", - "# Start the clock to measure how long it takes\n", - "start = time.time()\n", - "\n", - "# Initialize some parameters\n", - "pixel_size_hr = pixel_size/upsampling_factor # in nm\n", - "n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n", - "patch_size = patch_size*upsampling_factor\n", - "\n", - "# Dimensions of the high-res grid\n", - "Mhr = upsampling_factor*M # in pixels\n", - "Nhr = upsampling_factor*N # in pixels\n", - "\n", - "# Initialize the training patches and labels\n", - "patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n", - "spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n", - "heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n", - "\n", - "# Run over all frames and construct the training examples\n", - "k = 1 # current patch count\n", - "skip_counter = 0 # number of patches skipped due to low density\n", - "id_start = 0 # id position in LocData for current frame\n", - "print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n", - "\n", - "n_locs = len(LocData.index)\n", - "print('Total number of localizations: '+str(n_locs))\n", - "density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)\n", - "print('Density: '+str(round(density,2))+' locs/um^2')\n", - "n_locs_per_patch = patch_size**2*density\n", - 
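"# n_locs_per_patch (computed above) feeds into the empirical formula below, which rescales\n", - "# the default L2_weighting_factor of 100: higher emitter densities give a smaller weighting\n", - "# factor and sparser data a larger one, as described in section 3.2 above.\n", - 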
"\n", - "if Automatic_normalization:\n", - " # This empirical formulae attempts to balance the loss L2 function between the background and the bright spikes\n", - " # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6^2 0.1um pixel size and density of 3 (hence the 20.28), at upsampling_factor = 8\n", - " L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(upsampling_factor**2*20.28))\n", - " print('Normalization factor: '+str(round(L2_weighting_factor,2)))\n", - "\n", - "# -------------------- Patch generation loop --------------------\n", - "\n", - "print('-----------------------------------------------------------')\n", - "for (f, thisFrame) in enumerate(tqdm(Images)):\n", - "\n", - " # Upsample the frame\n", - " upsampledFrame = np.kron(thisFrame, np.ones((upsampling_factor,upsampling_factor)))\n", - " # Read all the provided high-resolution locations for current frame\n", - " DataFrame = LocData[LocData['frame'] == f+1].copy()\n", - "\n", - " # Get the approximated locations according to the high-res grid pixel size\n", - " Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n", - " Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n", - " id_start += len(DataFrame.index)\n", - "\n", - " # Build Localization image\n", - " LocImage = np.zeros((Mhr,Nhr))\n", - " LocImage[(Rhr_emitters, Chr_emitters)] = 1\n", - "\n", - " # Here, there's a choice between the original Gaussian (classification approach) and using the erf function\n", - " HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, gaussian_sigma) \n", - " # HeatMapImage = L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])), \n", - " # np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)), \n", - " # Mhr, pixel_size_hr)\n", - " \n", - "\n", - " # Generate random position for the top left corner of the patch\n", - " xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n", - " yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n", - "\n", - " for c in range(len(xc)):\n", - " if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n", - " skip_counter += 1\n", - " continue\n", - " \n", - " else:\n", - " # Limit maximal number of training examples to 15k\n", - " if k > max_num_patches:\n", - " break\n", - " else:\n", - " # Assign the patches to the right part of the images\n", - " patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n", - " spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n", - " heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n", - " k += 1 # increment current patch count\n", - "\n", - "# Remove the empty data\n", - "patches = patches[:k-1]\n", - "spikes = spikes[:k-1]\n", - "heatmaps = heatmaps[:k-1]\n", - "n_patches = k-1\n", - "\n", - "# -------------------- Failsafe --------------------\n", - "# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n", - "if ((k-1) < 5000):\n", - " # W = '\\033[0m' # white (normal)\n", - " # R = '\\033[31m' # red\n", - " print(bcolors.WARNING+'!! 
WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. !!'+bcolors.NORMAL)\n", - "\n", - "\n", - "\n", - "# -------------------- Displays --------------------\n", - "print('Number of patches skipped due to low density: '+str(skip_counter))\n", - "# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n", - "# print('Size of patches: '+str(dataSize)+' MB')\n", - "print(str(n_patches)+' patches were generated.')\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "# Display patches interactively with a slider\n", - "def scroll_patches(patch):\n", - " f = plt.figure(figsize=(16,6))\n", - " plt.subplot(1,3,1)\n", - " plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')\n", - " plt.title('Raw data (frame #'+str(patch)+')')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,3,2)\n", - " plt.imshow(heatmaps[patch-1], interpolation='nearest')\n", - " plt.title('Heat map')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,3,3)\n", - " plt.imshow(spikes[patch-1], interpolation='nearest')\n", - " plt.title('Localization map')\n", - " plt.axis('off');\n", - " \n", - " plt.savefig('/content/TrainingDataExample_DeepSTORM2D.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "\n", - "interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DSjXFMevK7Iz" - }, - "source": [ - "# **4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hVeyKU0MdAPx" - }, - "source": [ - "## **4.1. Select your paths and parameters**\n", - "\n", - "---\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n", - "\n", - "\n", - "**Training parameters**\n", - "\n", - "**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see 5). **Default value: 80**\n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n", - "\n", - "**`number_of_steps`:** Define the number of training steps by epoch. **If this value is set to 0**, by default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n", - "\n", - "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30** \n", - "\n", - "**`initial_learning_rate`:** This parameter represents the initial value to be used as learning rate in the optimizer. 
**Default value: 0.001**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "oa5cDZ7f_PF6", - "cellView": "form" - }, - "source": [ - "#@markdown ###Path to training images and parameters\n", - "\n", - "model_path = \"\" #@param {type: \"string\"} \n", - "model_name = \"\" #@param {type: \"string\"} \n", - "number_of_epochs = 200#@param {type:\"integer\"}\n", - "batch_size = 16#@param {type:\"integer\"}\n", - "\n", - "number_of_steps = 0#@param {type:\"integer\"}\n", - "percentage_validation = 30 #@param {type:\"number\"}\n", - "initial_learning_rate = 0.0001 #@param {type:\"number\"}\n", - "\n", - "\n", - "percentage_validation /= 100\n", - "if number_of_steps == 0: \n", - " number_of_steps = int((1-percentage_validation)*n_patches/batch_size)\n", - " print('Number of steps: '+str(number_of_steps))\n", - "\n", - "# Pretrained model path initialised here so next cell does not need to be run\n", - "h5_file_path = ''\n", - "Use_pretrained_model = False\n", - "\n", - "if not ('patches' in locals()):\n", - " # W = '\\033[0m' # white (normal)\n", - " # R = '\\033[31m' # red\n", - " print(bcolors.WARNING+'!! WARNING: No patches were found in memory currently. !!'+bcolors.NORMAL)\n", - "\n", - "Save_path = os.path.join(model_path, model_name)\n", - "if os.path.exists(Save_path):\n", - " print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n", - "\n", - "print('-----------------------------')\n", - "print('Training parameters set.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WIyEvQBWLp9n" - }, - "source": [ - "\n", - "## **4.2. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "oHL5g0w8LqR0", - "cellView": "form" - }, - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n", - "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n", - "\n", - "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - "# --------------------- Load the model from the choosen path ------------------------\n", - " if pretrained_model_choice == \"Model_from_file\":\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n", - "\n", - "# --------------------- Download the a model provided in the XXX ------------------------\n", - "\n", - " if pretrained_model_choice == \"Model_name\":\n", - " pretrained_model_name = \"Model_name\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path) \n", - " wget.download(\"\", pretrained_model_path)\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n", - "\n", - "# --------------------- Add additional pre-trained models here ------------------------\n", - "\n", - "\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n", - " if not os.path.exists(h5_file_path):\n", - " print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)\n", - " Use_pretrained_model = False\n", - "\n", - " \n", - "# If the model path contains a pretrain model, we load the training rate, \n", - " if os.path.exists(h5_file_path):\n", - "#Here we check if the learning rate can be loaded from the quality control folder\n", - " if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"pretrained network learning rate found\")\n", - " #find the last learning rate\n", - " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n", - " #Find the learning rate corresponding to the lowest validation loss\n", - " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n", - " #print(min_val_loss)\n", - " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n", - " if Weights_choice == \"last\":\n", - " print('Last learning rate: '+str(lastLearningRate))\n", - " if Weights_choice == \"best\":\n", - " print('Learning rate of best 
validation loss: '+str(bestLearningRate))\n", - " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)\n", - "\n", - "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n", - " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - "\n", - "\n", - "# Display info about the pretrained model to be loaded (or not)\n", - "if Use_pretrained_model:\n", - " print('Weights found in:')\n", - " print(h5_file_path)\n", - " print('will be loaded prior to training.')\n", - "\n", - "else:\n", - " print('No pretrained network will be used.')\n", - " h5_file_path = ''\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OADNcie-LHxA" - }, - "source": [ - "## **4.4. Start Training**\n", - "---\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "qDgMu_mAK8US", - "cellView": "form" - }, - "source": [ - "#@markdown ##Start training\n", - "\n", - "# Start the clock to measure how long it takes\n", - "start = time.time()\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "#Here we ensure that the learning rate set correctly when using pre-trained models\n", - "if Use_pretrained_model:\n", - " if Weights_choice == \"last\":\n", - " initial_learning_rate = lastLearningRate\n", - "\n", - " if Weights_choice == \"best\": \n", - " initial_learning_rate = bestLearningRate\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "\n", - "#here we check that no model with the same name already exist, if so delete\n", - "if os.path.exists(Save_path):\n", - " shutil.rmtree(Save_path)\n", - "\n", - "# Create the model folder!\n", - "os.makedirs(Save_path)\n", - "\n", - "# Export pdf summary \n", - "pdf_export(raw_data = load_raw_data, pretrained_model = Use_pretrained_model)\n", - "\n", - "# Let's go !\n", - "train_model(patches, heatmaps, Save_path, \n", - " steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,\n", - " upsampling_factor = upsampling_factor,\n", - " validation_split = percentage_validation,\n", - " initial_learning_rate = initial_learning_rate, \n", - " pretrained_model_path = h5_file_path,\n", - " L2_weighting_factor = L2_weighting_factor)\n", - "\n", - "# # Show info about the GPU memory useage\n", - "# !nvidia-smi\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "# export pdf after training to update the existing document\n", - "pdf_export(trained = True, raw_data = load_raw_data, pretrained_model = Use_pretrained_model)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4N7-ShZpLhwr" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "JDRsm7uKoBa-", - "cellView": "form" - }, - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the name of this folder as `QC_model_path` . \n", - "\n", - "QC_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_path = os.path.join(model_path, model_name)\n", - "\n", - "if os.path.exists(QC_model_path):\n", - " print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n", - "else:\n", - " print(bcolors.WARNING+'!! 
WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n", - " print('Please make sure you provide a valid model path before proceeding further.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Gw7KaHZUoHC4" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "qUc-JMOcoGNZ", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n", - "import csv\n", - "from matplotlib import pyplot as plt\n", - "\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "\n", - "with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[0]))\n", - " vallossDataFromCSV.append(float(row[1]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. 
epoch number (log scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'), bbox_inches='tight', pad_inches=0)\n", - "plt.show()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "32eNQjFioQkY" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "\n", - "This section will display SSIM maps and RSE maps as well as calculate total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\", using the corresponding localization data contained in \"QC_loc_folder\".\n", - "\n", - "**1. The SSIM (structural similarity) map** \n", - "\n", - "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with a Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n", - "\n", - "**mSSIM** is the SSIM value calculated across the entire window of both images.\n", - "\n", - "**The output below shows the SSIM maps with the mSSIM**\n", - "\n", - "**2. The RSE (Root Squared Error) map** \n", - "\n", - "This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n", - "\n", - "\n", - "**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n", - "\n", - "**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n", - "\n", - "**The output below shows the RSE maps with the NRMSE and PSNR values.**\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "dhlTnxC5lUZy", - "cellView": "form" - }, - "source": [ - "\n", - "# ------------------------ User input ------------------------\n", - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "QC_image_folder = \"\" #@param{type:\"string\"}\n", - "QC_loc_folder = \"\" #@param{type:\"string\"}\n", - "#@markdown Get pixel size from file?\n", - "get_pixel_size_from_file = True #@param {type:\"boolean\"}\n", - "#@markdown Otherwise, use this value:\n", - "pixel_size = 100 #@param {type:\"number\"}\n", - "\n", - "if get_pixel_size_from_file:\n", - " pixel_size_INPUT = None\n", - "else:\n", - " pixel_size_INPUT = pixel_size\n", - "\n", - "\n", - "# ------------------------ QC analysis loop over provided dataset ------------------------\n", - "\n", - "savePath = os.path.join(QC_model_path, 'Quality Control')\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - "with open(os.path.join(savePath, os.path.basename(QC_model_path)+\"_QC_metrics.csv\"), \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. GT PSNR\"])\n", - "\n", - " # These lists will be used to collect all the metrics values per slice\n", - " file_name_list = []\n", - " slice_number_list = []\n", - " mSSIM_GvP_list = []\n", - " mSSIM_GvWF_list = []\n", - " NRMSE_GvP_list = []\n", - " NRMSE_GvWF_list = []\n", - " PSNR_GvP_list = []\n", - " PSNR_GvWF_list = []\n", - "\n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - " for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):\n", - " print('--------------')\n", - " print(imageFilename)\n", - " print(locFilename)\n", - "\n", - " # Get the prediction\n", - " batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)\n", - "\n", - " # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n", - " thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n", - " thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n", - "\n", - " Mhr = thisPrediction.shape[0]\n", - " Nhr = thisPrediction.shape[1]\n", - "\n", - " if pixel_size_INPUT == None:\n", - " pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))\n", - "\n", - " upsampling_factor = int(Mhr/M)\n", - " print('Upsampling factor: '+str(upsampling_factor))\n", - " pixel_size_hr = pixel_size/upsampling_factor # in nm\n", - "\n", - " # Load the localization file and display the first\n", - " LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n", - "\n", - " x = np.array(list(LocData['x [nm]']))\n", - " y = np.array(list(LocData['y [nm]']))\n", - " locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)\n", - "\n", - " # Remove extension from filename\n", - " imageFilename_no_extension = os.path.splitext(imageFilename)[0]\n", - "\n", - " # io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)\n", - " 
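# The ground-truth image is rendered from the localization CSV with FromLoc2Image_SimpleHistogram at the super-resolved pixel size,\n", - "    # then saved (below) next to the 'Predicted_' and 'Widefield_' images used for the comparison.\n", - "    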
saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n", - " test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, normalize_gt=True)\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n", - " test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n", - "\n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n", - " index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n", - "\n", - "\n", - " # Save ssim_maps\n", - " img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n", - " # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n", - " saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)\n", - "\n", - "\n", - " img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n", - " # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n", - " saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)\n", - "\n", - " \n", - " # Calculate the Root Squared Error (RSE) maps\n", - " img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", - " img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))\n", - "\n", - " # Save SE maps\n", - " img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n", - " # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n", - " saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)\n", - "\n", - " img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n", - " # io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n", - " saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)\n", - "\n", - "\n", - " # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n", - "\n", - " # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n", - " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", - " NRMSE_GTvsWF = np.sqrt(np.mean(img_RSE_GTvsWF))\n", - " \n", - " # We can also measure the peak signal to noise ratio between the images\n", - " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", - " PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n", - "\n", - " writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n", - "\n", - " # Collect values to display in dataframe output\n", - " file_name_list.append(imageFilename)\n", - " mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n", - " mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n", - " NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n", - " NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n", - " PSNR_GvP_list.append(PSNR_GTvsPrediction)\n", - " 
PSNR_GvWF_list.append(PSNR_GTvsWF)\n", - "\n", - "\n", - "# Table with metrics as dataframe output\n", - "pdResults = pd.DataFrame(index = file_name_list)\n", - "pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n", - "pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n", - "pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n", - "pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n", - "pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list\n", - "pdResults[\"Wide-field v. GT PSNR\"] = PSNR_GvWF_list\n", - "\n", - "\n", - "# ------------------------ Display ------------------------\n", - "\n", - "print('--------------------------------------------')\n", - "@interact\n", - "def show_QC_results(file = list_files(QC_image_folder, 'tif')):\n", - "\n", - " plt.figure(figsize=(15,15))\n", - " # Target (Ground-truth)\n", - " plt.subplot(3,3,1)\n", - " plt.axis('off')\n", - " img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n", - " plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))\n", - " plt.title('Target',fontsize=15)\n", - "\n", - " # Wide-field\n", - " plt.subplot(3,3,2)\n", - " plt.axis('off')\n", - " img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n", - " plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n", - " plt.title('Widefield',fontsize=15)\n", - "\n", - " #Prediction\n", - " plt.subplot(3,3,3)\n", - " plt.axis('off')\n", - " img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n", - " plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n", - " plt.title('Prediction',fontsize=15)\n", - "\n", - " #Setting up colours\n", - " cmap = plt.cm.CMRmap\n", - "\n", - " #SSIM between GT and Source\n", - " plt.subplot(3,3,5)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - " img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n", - " imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. Widefield',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. GT mSSIM\"],3)),fontsize=14)\n", - " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - " #SSIM between GT and Prediction\n", - " plt.subplot(3,3,6)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - " img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n", - " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - " plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. 
Prediction',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. GT mSSIM\"],3)),fontsize=14)\n", - "\n", - " #Root Squared Error between GT and Source\n", - " plt.subplot(3,3,8)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - " img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n", - " imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n", - " plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Widefield',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. GT PSNR\"],3)),fontsize=14)\n", - " plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - " #Root Squared Error between GT and Prediction\n", - " plt.subplot(3,3,9)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - " img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n", - " imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. GT PSNR\"],3)),fontsize=14)\n", - " plt.savefig(QC_model_path+'/Quality Control/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n", - "print('--------------------------------------------')\n", - "pdResults.head()\n", - "\n", - "# Export pdf wth summary of QC results\n", - "qc_pdf_export()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5Q11jDczkLsZ" - }, - "source": [ - "## **5.4. Export your model into the BioImage Model Zoo format**\n", - "---\n", - "This section exports the model into the BioImage Model Zoo format so it can be used directly with DeepImageJ. The new files will be stored in the model folder specified at the beginning of Section 5. \n", - "\n", - "Once the cell is executed, you will find a new zip file with the name specified in `Trained_model_name.bioimage.io.model`.\n", - "\n", - "To use it with deepImageJ, download it and unzip it in the ImageJ/models/ or Fiji/models/ folder of your local machine. \n", - "\n", - "In ImageJ, open the example image given within the downloaded zip file. Go to Plugins > DeepImageJ > DeepImageJ Run. 
Choose this model from the list and click OK.\n", - "\n", - " More information at https://deepimagej.github.io/deepimagej/" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "QJuuDRFAkMIX", - "cellView": "form" - }, - "source": [ - "# ------------- User input ------------\n", - "# information about the model\n", - "# @markdown ##Introduce the metadata of the model architecture:\n", - "Trained_model_name = \"\" #@param {type:\"string\"}\n", - "Trained_model_authors = \"[Author 1, Author 2, Author 3]\" #@param {type:\"string\"}\n", - "Trained_model_description = \"\" #@param {type:\"string\"}\n", - "Trained_model_license = 'MIT'#@param {type:\"string\"}\n", - "Trained_model_references = [\"Nehme E. et al., Optica 2018;\", \"Lucas von Chamier et al., bioRxiv 2020\"]\n", - "Trained_model_DOI = [\"https://doi.org/10.1364/OPTICA.5.000458\", \"https://doi.org/10.1101/2020.03.20.000133\"]\n", - "\n", - "# information about the example image\n", - "#@markdown ##Do you want to choose the example image?\n", - "default_example_image = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input the path to the file:\n", - "example_image_file = \"\" #@param {type:\"string\"}\n", - "#@markdown ##Get pixel size from file?\n", - "get_pixel_size_from_file = True #@param {type:\"boolean\"}\n", - "#@markdown ###Otherwise, use this value (in nm):\n", - "PixelSize = 100 #@param {type:\"number\"}\n", - "\n", - "if default_example_image:\n", - " files = [file for file in list_files(QC_image_folder, 'tif')]\n", - " example_image_file = os.path.join(QC_image_folder,files[0])\n", - "\n", - "if get_pixel_size_from_file:\n", - " PixelSize, _, _ = getPixelSizeTIFFmetadata(example_image_file, display=True)\n", - "\n", - "# Create one example image (a 2D slice) and its output\n", - "\n", - "## Default values and data-specific ones.\n", - "matfile = sio.loadmat(os.path.join(QC_model_path,'model_metadata.mat'))\n", - "test_mean = matfile['mean_test'].item() # convert to scalar\n", - "test_std = matfile['std_test'].item() # convert to scalar\n", - "upsampling_factor = np.array(matfile['upsampling_factor'])\n", - "upsampling_factor = upsampling_factor.item() # convert to scalar\n", - "L2_weighting_factor = np.array(matfile['Normalization factor'])\n", - "L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n", - "thresh = 0.1\n", - "pixel_size_hr = PixelSize/upsampling_factor\n", - "neighborhood_size = 3\n", - "# max_layer_thresh = thresh*L2_weighting_factor\n", - "\n", - "# Read the input image\n", - "test_img_stack = io.imread(example_image_file)\n", - "if len(test_img_stack.shape)>2:\n", - " test_img = test_img_stack[0]\n", - "else:\n", - " test_img = test_img_stack\n", - "input_im = project_01(test_img)\n", - "input_im = normalize_im(input_im, test_mean, test_std)\n", - "input_im = np.kron(input_im, np.ones((upsampling_factor,upsampling_factor)))\n", - "\n", - "# get dataset dimensions\n", - "(M, N) = input_im.shape\n", - "\n", - "# Build the model for a bigger image\n", - "model = buildModel((M, N, 1))\n", - "# Load the trained weights\n", - "model.load_weights(os.path.join(QC_model_path,'weights_best.hdf5'))\n", - "\n", - "# Reshaping\n", - "input_im = np.expand_dims(input_im, axis=[0,-1])\n", - "input_im = input_im.astype(np.float32)\n", - "# Inference\n", - "predicted_density = model.predict(input_im)\n", - "# Post-processing\n", - "predicted_density[predicted_density < 0] = 0\n", - "test_prediction = predicted_density.sum(axis = 3).sum(axis = 0)\n", - "test_prediction /= 
L2_weighting_factor\n", - "\n", - "# # Reduce model input size if necessary to avoid out of memory errors in Fiji\n", - "# M = np.min((M, 512))\n", - "# N = np.min((N, 512))\n", - "# Build the model for a bigger image\n", - "model = buildModel((M, N, 1))\n", - "# Load the trained weights\n", - "model.load_weights(os.path.join(QC_model_path,'weights_best.hdf5'))\n", - "\n", - "# Run this cell to export the model to the BioImage Model Zoo format.\n", - "####\n", - "from pydeepimagej.yaml import BioImageModelZooConfig\n", - "# from pydeepimagej.yaml.bioimage_specifications import get_specification\n", - "import urllib\n", - "\n", - "# ------------- Execute bioimage model zoo configuration ------------\n", - "# Check minimum size: it is [8,8] for the 2D XY plane\n", - "pooling_steps = 0\n", - "for keras_layer in model.layers:\n", - " if keras_layer.name.startswith('max') or \"pool\" in keras_layer.name:\n", - " pooling_steps += 1\n", - "MinimumSize = [2**(pooling_steps), 2**(pooling_steps)]\n", - "\n", - "dij_config = BioImageModelZooConfig(model, MinimumSize)\n", - "# we avoid padding for SMLM\n", - "dij_config.Halo = [0, 0]\n", - "\n", - "# Model developer details\n", - "dij_config.Authors = Trained_model_authors[1:-1].split(',')\n", - "dij_config.Description = Trained_model_description\n", - "dij_config.Name = Trained_model_name\n", - "dij_config.References = Trained_model_references\n", - "dij_config.DOI = Trained_model_DOI\n", - "dij_config.License = Trained_model_license\n", - "\n", - "# Additional information about the model\n", - "dij_config.GitHub = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic'\n", - "dij_config.Date = datetime.now()\n", - "dij_config.Documentation = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki'\n", - "dij_config.Tags = ['ZeroCostDL4Mic', 'deepimagej', 'SMLM', 'super-resolution', 'image reconstruction']\n", - "dij_config.Framework = 'tensorflow'\n", - "\n", - "# Add the information about the test image. Note here PixelSize is given in nm\n", - "dij_config.add_test_info(test_img, test_prediction, [0.001*PixelSize, 0.001*PixelSize])\n", - "dij_config.create_covers([test_img, test_prediction])\n", - "dij_config.Covers = ['./input.png', './output.png']\n", - "\n", - "# Store the model weights\n", - "# ---------------------------------------\n", - "# used_bioimageio_model_for_training_URL = \"/Some/URL/bioimage.io/\"\n", - "# dij_config.Parent = used_bioimageio_model_for_training_URL\n", - "\n", - "# Add weights information\n", - "format_authors = [\"pydeepimagej\"]\n", - "dij_config.add_weights_formats(model, 'TensorFlow', \n", - " parent=\"keras_hdf5\",\n", - " authors=[a for a in format_authors])\n", - "dij_config.add_weights_formats(model, 'KerasHDF5', \n", - " authors=[a for a in format_authors])\n", - "\n", - "## Prepare preprocessing file\n", - "path_preprocessing = \"MeanNormalization.ijm\"\n", - "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/MeanNormalization.ijm\", path_preprocessing )\n", - "# Modify the threshold in the macro to the chosen threshold\n", - "ijmacro = open(path_preprocessing,\"r\") \n", - "list_of_lines = ijmacro. 
readlines()\n", - "# Change model especific parameters\n", - "list_of_lines[19] = \"paramMean = {};\\n\".format(test_mean)\n", - "list_of_lines[20] = \"paramStd = {};\\n\".format(test_std)\n", - "list_of_lines.insert(len(list_of_lines), '\\n')\n", - "list_of_lines.insert(len(list_of_lines), '// Scaling\\n')\n", - "list_of_lines.insert(len(list_of_lines), 'run(\"Scale...\", \"x={0} y={0} interpolation=None create title=upsampled_input\");\\n'.format(upsampling_factor, N*upsampling_factor, M*upsampling_factor))\n", - "ijmacro.close()\n", - "ijmacro = open(path_preprocessing,\"w\") \n", - "ijmacro. writelines(list_of_lines)\n", - "ijmacro. close()\n", - "\n", - "## Prepare postprocessing file for Maxima Localization\n", - "path_postprocessing_max = \"LocalMaximaSMLM.ijm\"\n", - "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/LocalMaximaSMLM.ijm\", path_postprocessing_max )\n", - "\n", - "ijmacro = open(path_postprocessing_max,\"r\") \n", - "list_of_lines = ijmacro. readlines()\n", - "list_of_lines[11] = \"thresh = {};\\n\".format(thresh)\n", - "list_of_lines[12] = \"L2_weighting_factor = {};\\n\".format(L2_weighting_factor)\n", - "list_of_lines[13] = \"neighborhood_size = {};\\n\".format(neighborhood_size)\n", - "list_of_lines[14] = \"pixelSize = {}; // in nm and after upsampling\\n\".format(pixel_size_hr)\n", - "ijmacro.close()\n", - "ijmacro = open(path_postprocessing_max,\"w\") \n", - "ijmacro. writelines(list_of_lines)\n", - "ijmacro. close()\n", - "\n", - "## Prepare postprocessing file for Averaged Maxima Localization\n", - "path_postprocessing_avg = \"AveragedMaximaSMLM.ijm\"\n", - "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/AveragedMaximaSMLM.ijm\", path_postprocessing_avg)\n", - "\n", - "ijmacro = open(path_postprocessing_avg,\"r\") \n", - "list_of_lines = ijmacro. readlines()\n", - "list_of_lines[11] = \"thresh = {};\\n\".format(thresh)\n", - "list_of_lines[12] = \"L2_weighting_factor = {};\\n\".format(L2_weighting_factor)\n", - "list_of_lines[13] = \"neighborhood_size = {};\\n\".format(neighborhood_size)\n", - "list_of_lines[14] = \"pixelSize = {}; // in nm and after upsampling\\n\".format(pixel_size_hr)\n", - "ijmacro.close()\n", - "ijmacro = open(path_postprocessing_avg,\"w\") \n", - "ijmacro. writelines(list_of_lines)\n", - "ijmacro. 
close()\n", - "\n", - "# Include the info about the macros \n", - "dij_config.Preprocessing = [path_preprocessing]\n", - "dij_config.Preprocessing_files = [path_preprocessing]\n", - "\n", - "dij_config.Postprocessing = [path_postprocessing_max]\n", - "dij_config.Postprocessing_files = [path_postprocessing_max]\n", - "\n", - "## EXPORT THE MODEL\n", - "deepimagej_model_path = os.path.join(QC_model_path, Trained_model_name+'.bioimage.io.model')\n", - "dij_config.export_model(deepimagej_model_path)\n", - "\n", - "## Add csv with maxima localization and their confidence: \n", - "# Maxima localization\n", - "max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, False)\n", - "bind, xind, yind, confidence = max_layer(predicted_density)\n", - "bind = bind.eval(session=tf.compat.v1.Session())\n", - "xind = xind.eval(session=tf.compat.v1.Session())\n", - "yind = yind.eval(session=tf.compat.v1.Session())\n", - "confidence = confidence.eval(session=tf.compat.v1.Session())\n", - "# normalizing the confidence by the L2_weighting_factor\n", - "confidence /= L2_weighting_factor \n", - "\n", - "# turn indices to nms and append to the results\n", - "xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n", - "# frmind = (bind.numpy() + b*batch_size + 1).tolist()\n", - "# xind = xind.numpy().tolist()\n", - "# yind = yind.numpy().tolist()\n", - "# confidence = confidence.numpy().tolist()\n", - "\n", - "frmind = (bind + 1).tolist()\n", - "xind = xind.tolist()\n", - "yind = yind.tolist()\n", - "confidence = confidence.tolist()\n", - "\n", - "with open(os.path.join(deepimagej_model_path, 'Localizations_resultImage_max.csv'), \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - " writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n", - " locs = list(zip(frmind, xind, yind, confidence))\n", - " writer.writerows(locs)\n", - "# Zip the bundled model to download\n", - "shutil.make_archive(deepimagej_model_path, 'zip', deepimagej_model_path)\n", - "print(\"Localization csv file has been added to {0}.zip.\".format(deepimagej_model_path))\n", - "\n", - "## Prepare the macro file to process a stack with the trained model\n", - "path_macro_for_stacks = \"DeepSTORM4stacksThunderSTORM.ijm\"\n", - "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/DeepSTORM4stacksThunderSTORM.ijm\", path_macro_for_stacks)\n", - "shutil.copy2(path_macro_for_stacks, os.path.join(deepimagej_model_path, path_macro_for_stacks))\n", - "\n", - "# Save the Averaged Maxima Localization\n", - "shutil.copy2(path_postprocessing_avg, os.path.join(deepimagej_model_path, path_postprocessing_avg))\n", - "\n", - "print(\"An ImageJ macro file to process a entire stack has been added to the folder {0} under the name {1}.\".format(deepimagej_model_path, path_macro_for_stacks))\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yTRou0izLjhd" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eAf8aBDmWTx7" - }, - "source": [ - "## **6.1 Generate image prediction and localizations from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n", - "\n", - "**`Data_folder`:** This folder should contain the images that you want to process with your trained network.\n", - "\n", - "**`Result_folder`:** This folder will contain the localization CSV files.\n", - "\n", - "**`batch_size`:** This parameter determines how many frames are processed in any single pass on the GPU. A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 4**\n", - "\n", - "**`threshold`:** This parameter determines the threshold for local maxima finding. The value is expected to reside in the range **[0,1]**. A higher `threshold` will result in fewer localizations. **DEFAULT: 0.1**\n", - "\n", - "**`neighborhood_size`:** This parameter determines the size of the neighborhood within which the prediction needs to be a local maximum, in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**\n", - "\n", - "**`use_local_average`:** This parameter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True** it will make inference slightly slower, depending on the size of the FOV. **DEFAULT: True**\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "7qn06T_A0lxf", - "cellView": "form" - }, - "source": [ - "\n", - "# ------------------------------- User input -------------------------------\n", - "#@markdown ### Data parameters\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Result_folder = \"\" #@param {type:\"string\"}\n", - "#@markdown Get pixel size from file?\n", - "get_pixel_size_from_file = True #@param {type:\"boolean\"}\n", - "#@markdown Otherwise, use this value (in nm):\n", - "pixel_size = 100 #@param {type:\"number\"}\n", - "\n", - "#@markdown ### Model parameters\n", - "#@markdown Do you want to use the model you just trained?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "#@markdown Otherwise, please provide path to the model folder below\n", - "prediction_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ### Prediction parameters\n", - "batch_size = 4#@param {type:\"integer\"}\n", - "\n", - "#@markdown ### Post processing parameters\n", - "threshold = 0.1#@param {type:\"number\"}\n", - "neighborhood_size = 3#@param {type:\"integer\"}\n", - "#@markdown Do you want to locally average the model output with CoG estimator ?\n", - "use_local_average = True #@param {type:\"boolean\"}\n", - "\n", - "\n", - "if get_pixel_size_from_file:\n", - " pixel_size = None\n", - "\n", - "if (Use_the_current_trained_model): \n", - " prediction_model_path = os.path.join(model_path, model_name)\n", - "\n", - "if os.path.exists(prediction_model_path):\n", - " print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n", - "else:\n", - " print(bcolors.WARNING+'!! 
WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n", - " print('Please make sure you provide a valid model path before proceeding further.')\n", - "\n", - "# inform user whether local averaging is being used\n", - "if use_local_average == True: \n", - " print('Using local averaging')\n", - "\n", - "if not os.path.exists(Result_folder):\n", - " print('Result folder was created.')\n", - " os.makedirs(Result_folder)\n", - "\n", - "\n", - "# ------------------------------- Run predictions -------------------------------\n", - "\n", - "start = time.time()\n", - "#%% This script tests the trained fully convolutional network based on the \n", - "# saved training weights, and normalization created using train_model.\n", - "\n", - "if os.path.isdir(Data_folder): \n", - " for filename in list_files(Data_folder, 'tif'):\n", - " # run the testing/reconstruction process\n", - " print(\"------------------------------------\")\n", - " print(\"Running prediction on: \"+ filename)\n", - " batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder, \n", - " batch_size, \n", - " threshold, \n", - " neighborhood_size, \n", - " use_local_average,\n", - " pixel_size = pixel_size)\n", - "\n", - "elif os.path.isfile(Data_folder):\n", - " batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder, \n", - " batch_size, \n", - " threshold, \n", - " neighborhood_size, \n", - " use_local_average, \n", - " pixel_size = pixel_size)\n", - "\n", - "\n", - "\n", - "print('--------------------------------------------------------------------')\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "\n", - "# ------------------------------- Interactive display -------------------------------\n", - "\n", - "print('--------------------------------------------------------------------')\n", - "print('---------------------------- Previews ------------------------------')\n", - "print('--------------------------------------------------------------------')\n", - "\n", - "if os.path.isdir(Data_folder): \n", - " @interact\n", - " def show_QC_results(file = list_files(Data_folder, 'tif')):\n", - "\n", - " plt.figure(figsize=(15,7.5))\n", - " # Wide-field\n", - " plt.subplot(1,2,1)\n", - " plt.axis('off')\n", - " img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n", - " plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n", - " plt.title('Widefield', fontsize=15)\n", - " # Prediction\n", - " plt.subplot(1,2,2)\n", - " plt.axis('off')\n", - " img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n", - " plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n", - " plt.title('Predicted',fontsize=15)\n", - "\n", - "if os.path.isfile(Data_folder):\n", - "\n", - " plt.figure(figsize=(15,7.5))\n", - " # Wide-field\n", - " plt.subplot(1,2,1)\n", - " plt.axis('off')\n", - " img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))\n", - " plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n", - " plt.title('Widefield', fontsize=15)\n", - " # Prediction\n", - " plt.subplot(1,2,2)\n", - " plt.axis('off')\n", - " img_Prediction = io.imread(os.path.join(Result_folder, 
'Predicted_'+os.path.basename(Data_folder)))\n", - " plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n", - " plt.title('Predicted',fontsize=15)\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZekzexaPmzFZ" - }, - "source": [ - "## **6.2 Drift correction**\n", - "---\n", - "\n", - "The visualization above is the raw output of the network and is displayed at the `upsampling_factor` chosen during model training. The display is a preview without any drift correction applied. This section performs drift correction using cross-correlation between time bins to estimate the drift.\n", - "\n", - "**`Loc_file_path`:** is the path to the localization file to use for visualization.\n", - "\n", - "**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n", - "\n", - "**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the image reconstructions used for the drift correction estimation (in **nm**). A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**\n", - "\n", - "**`number_of_bins`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bin are used to build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**\n", - "\n", - "**`polynomial_fit_degree`:** The drift obtained for each temporal bin needs to be interpolated to every single frame. This is performed by a polynomial fit, the degree of which is defined here. **DEFAULT: 4**\n", - "\n", - " The drift-corrected localization data is automatically saved in the `save_path` folder." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "hYtP_vh6mzUP", - "cellView": "form" - }, - "source": [ - "# @markdown ##Data parameters\n", - "Loc_file_path = \"\" #@param {type:\"string\"}\n", - "# @markdown Provide information about original data. 
Get the info automatically from the raw data?\n", - "Get_info_from_file = True #@param {type:\"boolean\"}\n", - "# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n", - "original_image_path = \"\" #@param {type:\"string\"}\n", - "# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n", - "image_width = 256#@param {type:\"integer\"}\n", - "image_height = 256#@param {type:\"integer\"}\n", - "pixel_size = 100 #@param {type:\"number\"}\n", - "\n", - "# @markdown ##Drift correction parameters\n", - "visualization_pixel_size = 20#@param {type:\"number\"}\n", - "number_of_bins = 50#@param {type:\"integer\"}\n", - "polynomial_fit_degree = 4#@param {type:\"integer\"}\n", - "\n", - "# @markdown ##Saving parameters\n", - "save_path = '' #@param {type:\"string\"}\n", - "\n", - "\n", - "# Let's go !\n", - "start = time.time()\n", - "\n", - "# Get info from the raw file if selected\n", - "if Get_info_from_file:\n", - " pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n", - "\n", - "# Read the localizations in\n", - "LocData = pd.read_csv(Loc_file_path)\n", - "\n", - "# Calculate a few variables \n", - "Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n", - "Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n", - "nFrames = max(LocData['frame'])\n", - "x_max = max(LocData['x [nm]'])\n", - "y_max = max(LocData['y [nm]'])\n", - "image_size = (Mhr, Nhr)\n", - "n_locs = len(LocData.index)\n", - "\n", - "print('Image size: '+str(image_size))\n", - "print('Number of frames in data: '+str(nFrames))\n", - "print('Number of localizations in data: '+str(n_locs))\n", - "\n", - "blocksize = math.ceil(nFrames/number_of_bins)\n", - "print('Number of frames per block: '+str(blocksize))\n", - "\n", - "blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()\n", - "xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n", - "yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n", - "\n", - "# Preparing the Reference image\n", - "photon_array = np.ones(yc_array.shape[0])\n", - "sigma_array = np.ones(yc_array.shape[0])\n", - "ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n", - "ImagesRef = np.rot90(ImageRef, k=2)\n", - "\n", - "xDrift = np.zeros(number_of_bins)\n", - "yDrift = np.zeros(number_of_bins)\n", - "\n", - "filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n", - "\n", - "with open(os.path.join(save_path, filename_no_extension+\"_DriftCorrectionData.csv\"), \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"Block #\", \"x-drift [nm]\",\"y-drift [nm]\"])\n", - "\n", - " for b in tqdm(range(number_of_bins)):\n", - "\n", - " blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()\n", - " xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n", - " yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n", - "\n", - " photon_array = np.ones(yc_array.shape[0])\n", - " sigma_array = np.ones(yc_array.shape[0])\n", - " ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = 
visualization_pixel_size)\n", - "\n", - " XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')\n", - " yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')\n", - "\n", - " # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, visualization_pixel_size)\n", - " # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)\n", - " writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])\n", - "\n", - "\n", - "print('--------------------------------------------------------------------')\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "print('Fitting drift data...')\n", - "bin_number = np.arange(number_of_bins)*blocksize + blocksize/2\n", - "xDrift = (xDrift-xDrift[0])*visualization_pixel_size\n", - "yDrift = (yDrift-yDrift[0])*visualization_pixel_size\n", - "\n", - "xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)\n", - "yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)\n", - "\n", - "xDriftFit = np.poly1d(xDriftCoeff)\n", - "yDriftFit = np.poly1d(yDriftCoeff)\n", - "bins = np.arange(nFrames)\n", - "xDriftInterpolated = xDriftFit(bins)\n", - "yDriftInterpolated = yDriftFit(bins)\n", - "\n", - "\n", - "# ------------------ Displaying the image results ------------------\n", - "\n", - "plt.figure(figsize=(15,10))\n", - "plt.plot(bin_number,xDrift, 'r+', label='x-drift')\n", - "plt.plot(bin_number,yDrift, 'b+', label='y-drift')\n", - "plt.plot(bins,xDriftInterpolated, 'r-', label='y-drift (fit)')\n", - "plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')\n", - "plt.title('Cross-correlation estimated drift')\n", - "plt.ylabel('Drift [nm]')\n", - "plt.xlabel('Bin number')\n", - "plt.legend();\n", - "\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\", hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "\n", - "# ------------------ Actual drift correction -------------------\n", - "\n", - "print('Correcting localization data...')\n", - "xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)\n", - "yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)\n", - "frames = LocData['frame'].to_numpy(dtype=np.int32)\n", - "\n", - "\n", - "xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)\n", - "ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n", - "ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)\n", - "\n", - "\n", - "# ------------------ Displaying the imge results ------------------\n", - "plt.figure(figsize=(15,7.5))\n", - "# Raw\n", - "plt.subplot(1,2,1)\n", - "plt.axis('off')\n", - "plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))\n", - "plt.title('Raw', fontsize=15);\n", - "# Corrected\n", - "plt.subplot(1,2,2)\n", - "plt.axis('off')\n", - "plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))\n", - "plt.title('Corrected',fontsize=15);\n", - "\n", - "\n", - "# ------------------ Table with info -------------------\n", - 
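The drift estimation in this cell follows the scheme described in section 6.2: render each temporal bin as an image, cross-correlate it against the reference bin, take the correlation peak as the bin's shift, and interpolate the per-bin drift to every frame with a polynomial fit. A stripped-down NumPy/SciPy sketch of the same idea, with illustrative names and numbers (the notebook additionally refines the correlation peak to sub-pixel precision with `subPixelMaxLocalization`):

import numpy as np
from scipy.signal import fftconvolve

def shift_between(ref_img, bin_img):
    # Cross-correlation computed as convolution with the 180-degree rotated image
    xc = fftconvolve(ref_img, bin_img[::-1, ::-1], mode='same')
    peak = np.array(np.unravel_index(np.argmax(xc), xc.shape))
    return peak - np.array(xc.shape) // 2  # (dy, dx) shift in pixels, integer precision

# Per-bin drift (in nm) interpolated to every frame with a polynomial fit;
# keep the degree well below the number of bins (the cell above uses degree 4 on ~50 bins).
bin_centre_frames = np.array([250, 750, 1250, 1750, 2250, 2750])
x_drift_nm = np.array([0.0, 1.1, 2.6, 3.9, 5.3, 6.4])
x_drift_per_frame = np.poly1d(np.polyfit(bin_centre_frames, x_drift_nm, 3))(np.arange(3000))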
"driftCorrectedLocData = pd.DataFrame()\n", - "driftCorrectedLocData['frame'] = frames\n", - "driftCorrectedLocData['x [nm]'] = xc_array_Corr\n", - "driftCorrectedLocData['y [nm]'] = yc_array_Corr\n", - "driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']\n", - "\n", - "driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))\n", - "print('-------------------------------')\n", - "print('Corrected localizations saved.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mzOuc-V7rB-r" - }, - "source": [ - "## **6.3 Visualization of the localizations**\n", - "---\n", - "\n", - "\n", - "The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.\n", - "\n", - "**`Loc_file_path`:** is the path to the localization file to use for visualization.\n", - "\n", - "**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n", - "\n", - "**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). **DEFAULT: 10**\n", - "\n", - "**`visualization_mode`:** This parameter defines what visualization method is used to visualize the final image. NOTES: The Integrated Gaussian can be quite slow. **DEFAULT: Simple histogram.**\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "876yIXnqq-nW", - "cellView": "form" - }, - "source": [ - "# @markdown ##Data parameters\n", - "Use_current_drift_corrected_localizations = True #@param {type:\"boolean\"}\n", - "# @markdown Otherwise provide a localization file path\n", - "Loc_file_path = \"\" #@param {type:\"string\"}\n", - "# @markdown Provide information about original data. 
Get the info automatically from the raw data?\n", - "Get_info_from_file = True #@param {type:\"boolean\"}\n", - "# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n", - "original_image_path = \"\" #@param {type:\"string\"}\n", - "# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n", - "image_width = 256#@param {type:\"integer\"}\n", - "image_height = 256#@param {type:\"integer\"}\n", - "pixel_size = 100#@param {type:\"number\"}\n", - "\n", - "# @markdown ##Visualization parameters\n", - "visualization_pixel_size = 10#@param {type:\"number\"}\n", - "visualization_mode = \"Simple histogram\" #@param [\"Simple histogram\", \"Integrated Gaussian (SLOW!)\"]\n", - "\n", - "if not Use_current_drift_corrected_localizations:\n", - " filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n", - "\n", - "\n", - "if Get_info_from_file:\n", - " pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n", - "\n", - "if Use_current_drift_corrected_localizations:\n", - " LocData = driftCorrectedLocData\n", - "else:\n", - " LocData = pd.read_csv(Loc_file_path)\n", - "\n", - "Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n", - "Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n", - "\n", - "\n", - "nFrames = max(LocData['frame'])\n", - "x_max = max(LocData['x [nm]'])\n", - "y_max = max(LocData['y [nm]'])\n", - "image_size = (Mhr, Nhr)\n", - "\n", - "print('Image size: '+str(image_size))\n", - "print('Number of frames in data: '+str(nFrames))\n", - "print('Number of localizations in data: '+str(len(LocData.index)))\n", - "\n", - "xc_array = LocData['x [nm]'].to_numpy()\n", - "yc_array = LocData['y [nm]'].to_numpy()\n", - "if (visualization_mode == 'Simple histogram'):\n", - " locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n", - "elif (visualization_mode == 'Shifted histogram'):\n", - " print(bcolors.WARNING+'Method not implemented yet!'+bcolors.NORMAL)\n", - " locImage = np.zeros(image_size)\n", - "elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):\n", - " photon_array = np.ones(xc_array.shape)\n", - " sigma_array = np.ones(xc_array.shape)\n", - " locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)\n", - "\n", - "print('--------------------------------------------------------------------')\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "# Display\n", - "plt.figure(figsize=(20,10))\n", - "plt.axis('off')\n", - "# plt.imshow(locImage, cmap='gray');\n", - "plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));\n", - "\n", - "\n", - "LocData.head()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "PdOhWwMn1zIT", - "cellView": "form" - }, - "source": [ - "# @markdown ---\n", - "# @markdown #Play this cell to save the visualization\n", - "# @markdown ####Please select a path to the folder where to save the 
visualization.\n", - "save_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if not os.path.exists(save_path):\n", - " os.makedirs(save_path)\n", - " print('Folder created.')\n", - "\n", - "saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)\n", - "print('Image saved.')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1EszIF4Dkz_n" - }, - "source": [ - "## **6.4. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UgN-NooKk3nV" - }, - "source": [ - "\n", - "#**Thank you for using Deep-STORM 2D!**" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Deep-STORM_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb","provenance":[{"file_id":"1kD3rjN5XX5C33cQuX1DVc_n89cMqNvS_","timestamp":1610633423190},{"file_id":"1w95RljMrg15FLDRnEJiLIEa-lW-jEjQS","timestamp":1602684895691},{"file_id":"169qcwQo-yw15PwoGatXAdBvjs4wt_foD","timestamp":1592147948265},{"file_id":"1gjRCgDORKi_GNBu4QnVCBkSWrfPtqL-E","timestamp":1588525976305},{"file_id":"1DFy6aCi1XAVdjA5KLRZirB2aMZkMFdv-","timestamp":1587998755430},{"file_id":"1NpzigQoXGy3GFdxh4_jvG1PnBfyrcpBs","timestamp":1587569988032},{"file_id":"1jdI540qAfMSQwjnMhoAFkGJH9EbHwNSf","timestamp":1587486196143}],"collapsed_sections":[],"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"FpCtYevLHfl4"},"source":["# **Deep-STORM (2D)**\n","\n","---\n","\n","Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. The network is trained using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network output a super-resolution image with increased pixel density (typically upsampling factor of 8 in each dimension).\n","\n","Deep-STORM has **two key advantages**:\n","- SMLM reconstruction at high density of emitters\n","- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n","\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. 
Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n","\n","And source code found in: /~https://github.com/EliasNehme/Deep-STORM\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"wyzTn3IcHq6Y"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"bEy4EBXHHyAX"},"source":["#**0. Before getting started**\n","---\n"," Deep-STORM is able to train on simulated dataset of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. 
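To make the simulation idea concrete, here is a minimal, self-contained sketch of what one simulated training pair looks like conceptually: random emitter positions rendered once as a diffraction-limited camera frame and once as a "spike" target on an 8x-upsampled grid. This is not the notebook's simulator (section 3.1.b generates the actual training data); every name and number below is illustrative.

import numpy as np

rng = np.random.default_rng(0)
cam_px, pixel_nm, upsampling, sigma_nm = 64, 100, 8, 150  # camera grid, pixel size, target upsampling, PSF width
xy_nm = rng.uniform(0, cam_px * pixel_nm, size=(30, 2))   # 30 emitters placed at random in the field of view

# Diffraction-limited frame: one Gaussian blob per emitter on the camera grid
yy, xx = (np.mgrid[0:cam_px, 0:cam_px] + 0.5) * pixel_nm
frame = sum(np.exp(-((xx - x)**2 + (yy - y)**2) / (2 * sigma_nm**2)) for x, y in xy_nm)

# Ground-truth "spike" image on the upsampled grid: one count per emitter
hr = cam_px * upsampling
target = np.zeros((hr, hr))
idx = np.clip((xy_nm / (pixel_nm / upsampling)).astype(int), 0, hr - 1)
np.add.at(target, (idx[:, 1], idx[:, 0]), 1.0)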
Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"jRnQZWSZhArJ"},"source":["# **1. Install Deep-STORM and dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"qi1fN0tz3E8W"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"cellView":"form","id":"9tMgoFUS3CG8"},"source":["#@markdown ##Play to install DeepSTORM dependencies\n","\n","!pip install pydeepimagej==2.1.2\n","!pip install data\n","!pip install fpdf\n","!pip install h5py==2.10\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Kl-3ilMs3QcO"},"source":["\n","## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
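After the runtime has restarted, an optional sanity check that the pinned packages installed in section 1.1 are the ones actually in use (package names taken from the install cell above):

import pkg_resources
for pkg in ('h5py', 'pydeepimagej', 'fpdf'):
    print(pkg, pkg_resources.get_distribution(pkg).version)  # h5py should report 2.10.x and pydeepimagej 2.1.2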
"]},{"cell_type":"markdown","metadata":{"id":"6X3WptB33T7U"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"kSrZMo3X_NhO","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'Deep-STORM'\n","\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install Deep-STORM\n","# %% Model definition + helper functions\n","\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","\n","# Import keras modules and libraries\n","from tensorflow import keras\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer\n","from tensorflow.keras.callbacks import Callback\n","from tensorflow.keras import backend as K\n","from tensorflow.keras import optimizers, losses\n","\n","from tensorflow.keras.preprocessing.image import ImageDataGenerator\n","from tensorflow.keras.callbacks import ModelCheckpoint\n","from tensorflow.keras.callbacks import ReduceLROnPlateau\n","from skimage.transform import warp\n","from skimage.transform import SimilarityTransform\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from scipy.signal import fftconvolve\n","\n","# Import common libraries\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import h5py\n","import scipy.io as sio\n","from os.path import abspath\n","from sklearn.model_selection import train_test_split\n","from skimage import io\n","import time\n","import os\n","import shutil\n","import csv\n","from PIL import Image \n","from PIL.TiffTags import TAGS\n","from scipy.ndimage import gaussian_filter\n","import math\n","from astropy.visualization import simple_norm\n","from sys import getsizeof\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from datetime import datetime\n","\n","# For sliders and dropdown menu, progress bar\n","from ipywidgets import interact\n","import ipywidgets as 
widgets\n","from tqdm import tqdm\n","\n","# For Multi-threading in simulation\n","from numba import njit, prange\n","\n","\n","\n","# define a function that projects and rescales an image to the range [0,1]\n","def project_01(im):\n"," im = np.squeeze(im)\n"," min_val = im.min()\n"," max_val = im.max()\n"," return (im - min_val)/(max_val - min_val)\n","\n","# normalize image given mean and std\n","def normalize_im(im, dmean, dstd):\n"," im = np.squeeze(im)\n"," im_norm = np.zeros(im.shape,dtype=np.float32)\n"," im_norm = (im - dmean)/dstd\n"," return im_norm\n","\n","# Define the loss history recorder\n","class LossHistory(Callback):\n"," def on_train_begin(self, logs={}):\n"," self.losses = []\n","\n"," def on_batch_end(self, batch, logs={}):\n"," self.losses.append(logs.get('loss'))\n"," \n","# Define a matlab like gaussian 2D filter\n","def matlab_style_gauss2D(shape=(7,7),sigma=1):\n"," \"\"\" \n"," 2D gaussian filter - should give the same result as:\n"," MATLAB's fspecial('gaussian',[shape],[sigma]) \n"," \"\"\"\n"," m,n = [(ss-1.)/2. for ss in shape]\n"," y,x = np.ogrid[-m:m+1,-n:n+1]\n"," h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n"," h.astype(dtype=K.floatx())\n"," h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n"," sumh = h.sum()\n"," if sumh != 0:\n"," h /= sumh\n"," h = h*2.0\n"," h = h.astype('float32')\n"," return h\n","\n","# Expand the filter dimensions\n","psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n","gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n","\n","# Combined MSE + L1 loss\n","def L1L2loss(input_shape):\n"," def bump_mse(heatmap_true, spikes_pred):\n","\n"," # generate the heatmap corresponding to the predicted spikes\n"," heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n","\n"," # heatmaps MSE\n"," loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n","\n"," # l1 on the predicted spikes\n"," loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n"," return loss_heatmaps + loss_spikes\n"," return bump_mse\n","\n","# Define the concatenated conv2, batch normalization, and relu block\n","def conv_bn_relu(nb_filter, rk, ck, name):\n"," def f(input):\n"," conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n"," padding=\"same\", use_bias=False,\\\n"," kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n"," conv_norm = BatchNormalization(name='BN-'+name)(conv)\n"," conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n"," return conv_norm_relu\n"," return f\n","\n","# Define the model architechture\n","def CNN(input,names):\n"," Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n"," pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n"," Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n"," pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n"," Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n"," pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n"," Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)\n"," up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n"," Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n"," up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n"," Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n"," up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n"," Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n"," return Features7\n","\n","# Define the Model building for an arbitrary input size\n","def 
buildModel(input_dim, initial_learning_rate = 0.001):\n"," input_ = Input (shape = (input_dim))\n"," act_ = CNN (input_,'CNN')\n"," density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n"," activation=\"linear\", use_bias = False,\\\n"," kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n"," model = Model (inputs= input_, outputs=density_pred)\n"," opt = optimizers.Adam(lr = initial_learning_rate)\n"," model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n"," return model\n","\n","\n","# define a function that trains a model for a given data SNR and density\n","def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):\n"," \n"," \"\"\"\n"," This function trains a CNN model on the desired training set, given the \n"," upsampled training images and labels generated in MATLAB.\n"," \n"," # Inputs\n"," # TO UPDATE ----------\n","\n"," # Outputs\n"," function saves the weights of the trained model to a hdf5, and the \n"," normalization factors to a mat file. These will be loaded later for testing \n"," the model in test_model. \n"," \"\"\"\n"," \n"," # for reproducibility\n"," np.random.seed(123)\n","\n"," X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)\n"," print('Number of training examples: %d' % X_train.shape[0])\n"," print('Number of validation examples: %d' % X_test.shape[0])\n"," \n"," # Setting type\n"," X_train = X_train.astype('float32')\n"," X_test = X_test.astype('float32')\n"," y_train = y_train.astype('float32')\n"," y_test = y_test.astype('float32')\n","\n"," \n"," #===================== Training set normalization ==========================\n"," # normalize training images to be in the range [0,1] and calculate the \n"," # training set mean and std\n"," mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n"," std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train[i, :, :] = project_01(X_train[i, :, :])\n"," mean_train[i] = X_train[i, :, :].mean()\n"," std_train[i] = X_train[i, :, :].std()\n","\n"," # resulting normalized training images\n"," mean_val_train = mean_train.mean()\n"," std_val_train = std_train.mean()\n"," X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n"," \n"," # patch size\n"," psize = X_train_norm.shape[1]\n","\n"," # Reshaping\n"," X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n","\n"," # ===================== Test set normalization ==========================\n"," # normalize test images to be in the range [0,1] and calculate the test set \n"," # mean and std\n"," mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n"," std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test[i, :, :] = project_01(X_test[i, :, :])\n"," mean_test[i] = X_test[i, :, :].mean()\n"," std_test[i] = X_test[i, :, :].std()\n","\n"," # resulting normalized test images\n"," mean_val_test = mean_test.mean()\n"," std_val_test = std_test.mean()\n"," X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n"," \n"," # Reshaping\n"," 
X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n","\n"," # Reshaping labels\n"," Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n"," Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n","\n"," # Save datasets to a matfile to open later in matlab\n"," mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test, \"upsampling_factor\": upsampling_factor, \"Normalization factor\": L2_weighting_factor}\n"," sio.savemat(os.path.join(modelPath,\"model_metadata.mat\"), mdict)\n","\n","\n"," # Set the dimensions ordering according to tensorflow consensous\n"," # K.set_image_dim_ordering('tf')\n"," K.set_image_data_format('channels_last')\n","\n"," # Save the model weights after each epoch if the validation loss decreased\n"," checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_best.hdf5\"), verbose=1,\n"," save_best_only=True)\n","\n"," # Change learning when loss reaches a plataeu\n"," change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)\n"," \n"," # Model building and complitation\n"," model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)\n"," model.summary()\n","\n"," # Load pretrained model\n"," if not pretrained_model_path:\n"," print('Using random initial model weights.')\n"," else:\n"," print('Loading model weights from '+pretrained_model_path)\n"," model.load_weights(pretrained_model_path)\n"," \n"," # Create an image data generator for real time data augmentation\n"," datagen = ImageDataGenerator(\n"," featurewise_center=False, # set input mean to 0 over the dataset\n"," samplewise_center=False, # set each sample mean to 0\n"," featurewise_std_normalization=False, # divide inputs by std of the dataset\n"," samplewise_std_normalization=False, # divide each input by its std\n"," zca_whitening=False, # apply ZCA whitening\n"," rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n"," width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n"," height_shift_range=0., # randomly shift images vertically (fraction of total height)\n"," zoom_range=0.,\n"," shear_range=0.,\n"," horizontal_flip=False, # randomly flip images\n"," vertical_flip=False, # randomly flip images\n"," fill_mode='constant',\n"," data_format=K.image_data_format())\n","\n"," # Fit the image generator on the training data\n"," datagen.fit(X_train_norm)\n"," \n"," # loss history recorder\n"," history = LossHistory()\n","\n"," # Inform user training begun\n"," print('-------------------------------')\n"," print('Training model...')\n","\n"," # Fit model on the batches generated by datagen.flow()\n"," train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \n"," steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \n"," validation_data=(X_test_norm, Y_test), \n"," callbacks=[history, checkpointer, change_lr]) \n","\n"," # Inform user training ended\n"," print('-------------------------------')\n"," print('Training Complete!')\n"," \n"," # Save the last model\n"," model.save(os.path.join(modelPath, 'weights_last.hdf5'))\n","\n"," # convert the history.history dict to a pandas DataFrame: \n"," lossData = pd.DataFrame(train_history.history) \n","\n"," if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n"," shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n","\n"," os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n","\n"," # The training evaluation.csv is saved 
(overwrites the Files if needed). \n"," lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n"," with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss','learning rate'])\n"," for i in range(len(train_history.history['loss'])):\n"," writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])\n","\n"," return\n","\n","\n","# Normalization functions from Martin Weigert used in CARE\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","# Multi-threaded Erf-based image construction\n","@njit(parallel=True)\n","def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," erfImage = np.zeros((w, h))\n"," for ij in prange(w*h):\n"," j = int(ij/w)\n"," i = ij - j*w\n"," for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n"," # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n"," if (sigma > 0) and (photon > 0):\n"," S = sigma*math.sqrt(2)\n"," x = i*pixel_size - xc\n"," y = j*pixel_size - yc\n"," # Don't bother if the emitter is further than 4 sigma from the centre of the pixel\n"," if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n"," ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n"," ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n"," erfImage[j][i] += 0.25*photon*ErfX*ErfY\n"," return erfImage\n","\n","\n","@njit(parallel=True)\n","def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," 
locImage = np.zeros((image_size[0],image_size[1]) )\n"," n_locs = len(xc_array)\n","\n"," for e in prange(n_locs):\n"," locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1\n","\n"," return locImage\n","\n","\n","\n","def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n"," with Image.open(TIFFpath) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n","\n"," # TIFF tags\n"," # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n"," # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n"," ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n"," width = meta_dict['ImageWidth'][0]\n"," height = meta_dict['ImageLength'][0]\n","\n"," xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n","\n"," if len(xResolution) == 1:\n"," xResolution = xResolution[0]\n"," elif len(xResolution) == 2:\n"," xResolution = xResolution[0]/xResolution[1]\n"," else:\n"," print('Image resolution not defined.')\n"," xResolution = 1\n","\n"," if ResolutionUnit == 2:\n"," # Units given are in inches\n"," pixel_size = 0.025*1e9/xResolution\n"," elif ResolutionUnit == 3:\n"," # Units given are in cm\n"," pixel_size = 0.01*1e9/xResolution\n"," else: \n"," # ResolutionUnit is therefore 1\n"," print('Resolution unit not defined. Assuming: um')\n"," pixel_size = 1e3/xResolution\n","\n"," if display:\n"," print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n"," print('Image size: '+str(width)+'x'+str(height))\n"," \n"," return (pixel_size, width, height)\n","\n","\n","def saveAsTIF(path, filename, array, pixel_size):\n"," \"\"\"\n"," Image saving using PIL to save as .tif format\n"," # Input \n"," path - path where it will be saved\n"," filename - name of the file to save (no extension)\n"," array - numpy array conatining the data at the required format\n"," pixel_size - physical size of pixels in nanometers (identical for x and y)\n"," \"\"\"\n","\n"," # print('Data type: '+str(array.dtype))\n"," if (array.dtype == np.uint16):\n"," mode = 'I;16'\n"," elif (array.dtype == np.uint32):\n"," mode = 'I'\n"," else:\n"," mode = 'F'\n","\n"," # Rounding the pixel size to the nearest number that divides exactly 1cm.\n"," # Resolution needs to be a rational number --> see TIFF format\n"," # pixel_size = 10000/(round(10000/pixel_size))\n","\n"," if len(array.shape) == 2:\n"," im = Image.fromarray(array)\n"," im.save(os.path.join(path, filename+'.tif'),\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n","\n"," elif len(array.shape) == 3:\n"," imlist = []\n"," for frame in array:\n"," imlist.append(Image.fromarray(frame))\n","\n"," imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,\n"," append_images=imlist[1:],\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n"," return\n","\n","\n","\n","\n","class Maximafinder(Layer):\n"," def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):\n"," super(Maximafinder, self).__init__(**kwargs)\n"," self.thresh = tf.constant(thresh, dtype=tf.float32)\n"," self.nhood = neighborhood_size\n"," self.use_local_avg = use_local_avg\n","\n"," def build(self, input_shape):\n"," if self.use_local_avg is True:\n"," self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], 
dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n","\n"," def call(self, inputs):\n","\n"," # local maxima positions\n"," max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)\n"," cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)\n"," indices = tf.where(cond)\n"," bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]\n"," confidence = tf.gather_nd(inputs, indices)\n","\n"," # local CoG estimator\n"," if self.use_local_avg:\n"," x_image = K.conv2d(inputs, self.kernel_x, padding='same')\n"," y_image = K.conv2d(inputs, self.kernel_y, padding='same')\n"," sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')\n"," confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)\n"," x_local = tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))\n"," y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))\n"," xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)\n"," else:\n"," xind = tf.cast(xind, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32)\n"," \n"," return bind, xind, yind, confidence\n","\n"," def get_config(self):\n","\n"," # Implement get_config to enable serialization. This is optional.\n"," base_config = super(Maximafinder, self).get_config()\n"," config = {}\n"," return dict(list(base_config.items()) + list(config.items()))\n","\n","\n","\n","# ------------------------------- Prediction with postprocessing function-------------------------------\n","def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):\n"," \"\"\"\n"," This function tests a trained model on the desired test set, given the \n"," tiff stack of test images, learned weights, and normalization factors.\n"," \n"," # Inputs\n"," dataPath - the path to the folder containing the tiff stack(s) to run prediction on \n"," filename - the name of the file to process\n"," modelPath - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model\n"," savePath - the path to the folder where to save the prediction\n"," batch_size. 
- the number of frames to predict on for each iteration\n"," thresh - threshoold percentage from the maximum of the gaussian scaling\n"," neighborhood_size - the size of the neighborhood for local maxima finding\n"," use_local_average - Boolean whether to perform local averaging or not\n"," \"\"\"\n"," \n"," # load mean and std\n"," matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))\n"," test_mean = np.array(matfile['mean_test'])\n"," test_std = np.array(matfile['std_test']) \n"," upsampling_factor = np.array(matfile['upsampling_factor'])\n"," upsampling_factor = upsampling_factor.item() # convert to scalar\n"," L2_weighting_factor = np.array(matfile['Normalization factor'])\n"," L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n","\n"," # Read in the raw file\n"," Images = io.imread(os.path.join(dataPath, filename))\n"," if pixel_size == None:\n"," pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)\n"," pixel_size_hr = pixel_size/upsampling_factor\n","\n"," # get dataset dimensions\n"," (nFrames, M, N) = Images.shape\n"," print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')\n","\n"," # Build the model for a bigger image\n"," model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))\n","\n"," # Load the trained weights\n"," model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))\n","\n"," # add a post-processing module\n"," max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)\n","\n"," # Initialise the results: lists will be used to collect all the localizations\n"," frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []\n","\n"," # Initialise the results\n"," Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n"," Widefield = np.zeros((M, N), dtype=np.float32)\n","\n"," # run model in batches\n"," n_batches = math.ceil(nFrames/batch_size)\n"," for b in tqdm(range(n_batches)):\n","\n"," nF = min(batch_size, nFrames - b*batch_size)\n"," Images_norm = np.zeros((nF, M, N),dtype=np.float32)\n"," Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n","\n"," # Upsampling using a simple nearest neighbor interp and calculating - MULTI-THREAD this?\n"," for f in range(nF):\n"," Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])\n"," Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)\n"," Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield += Images[b*batch_size+f,:,:]\n","\n"," # Reshaping\n"," Images_upsampled = np.expand_dims(Images_upsampled,axis=3)\n","\n"," # Run prediction and local amxima finding\n"," predicted_density = model.predict_on_batch(Images_upsampled)\n"," predicted_density[predicted_density < 0] = 0\n"," Prediction += predicted_density.sum(axis = 3).sum(axis = 0)\n","\n"," bind, xind, yind, confidence = max_layer(predicted_density)\n"," bind = bind.eval(session=tf.compat.v1.Session())\n"," xind = xind.eval(session=tf.compat.v1.Session())\n"," yind = yind.eval(session=tf.compat.v1.Session())\n"," confidence = confidence.eval(session=tf.compat.v1.Session())\n","\n"," # normalizing the confidence by the L2_weighting_factor\n"," confidence /= L2_weighting_factor \n","\n"," # turn indices to nms and append to the results\n"," xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n"," # frmind = (bind.numpy() + b*batch_size + 1).tolist()\n"," # xind 
= xind.numpy().tolist()\n"," # yind = yind.numpy().tolist()\n"," # confidence = confidence.numpy().tolist()\n"," frmind = (bind + b*batch_size + 1).tolist()\n"," xind = xind.tolist()\n"," yind = yind.tolist()\n"," confidence = confidence.tolist()\n"," \n"," frame_number_list += frmind\n"," x_nm_list += xind\n"," y_nm_list += yind\n"," confidence_au_list += confidence\n","\n"," # Open and create the csv file that will contain all the localizations\n"," if use_local_avg:\n"," ext = '_avg'\n"," else:\n"," ext = '_max'\n"," with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n"," locs = list(zip(frame_number_list, x_nm_list, y_nm_list, confidence_au_list))\n"," writer.writerows(locs)\n","\n"," # Save the prediction and widefield image\n"," Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield = np.float32(Widefield)\n","\n"," # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)\n"," # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)\n","\n"," saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)\n"," saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)\n","\n","\n"," return\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n","\n","\n","\n","def list_files(directory, extension):\n"," return (f for f in os.listdir(directory) if f.endswith('.' + extension))\n","\n","\n","# @njit(parallel=True)\n","def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):\n"," xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')\n"," centralPatch = XC[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]\n","\n"," if (method == 'MAX'):\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n","\n"," elif (method == 'CoM'):\n"," x0 = 0\n"," y0 = 0\n"," S = 0\n"," for xy in range(patch_size*patch_size):\n"," y = math.floor(xy/patch_size)\n"," x = xy - y*patch_size\n"," x0 += x*array[x,y]\n"," y0 += y*array[x,y]\n"," S = array[x,y]\n"," \n"," x0 = x0/S - patch_size/2 + xMaxInd\n"," y0 = y0/S - patch_size/2 + yMaxInd\n"," \n"," elif (method == 'Radiality'):\n"," # Not implemented yet\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n"," \n"," return (x0, y0)\n","\n","\n","@njit(parallel=True)\n","def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):\n"," n_locs = xc_array.shape[0]\n"," xc_array_Corr = np.empty(n_locs)\n"," yc_array_Corr = np.empty(n_locs)\n"," \n"," for loc in prange(n_locs):\n"," xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]\n"," yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]\n","\n"," return (xc_array_Corr, yc_array_Corr)\n","\n","\n","print('--------------------------------')\n","print('DeepSTORM installation complete.')\n","\n","# Check if this is the latest version of the notebook\n","\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: 
'+Latest_Notebook_version)\n","\n","\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","# Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","# if Notebook_version == list(Latest_notebook_version.columns):\n","# print(\"This notebook is up-to-date.\")\n","\n","# if not Notebook_version == list(Latest_notebook_version.columns):\n","# print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, raw_data = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," \n"," #model_name = 'little_CARE_test'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hours)+ \"hour(s) \"+str(minutes)+\"min(s) \"+str(round(seconds))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n"," if raw_data == True:\n"," shape = (M,N)\n"," else:\n"," shape = (int(FOV_size/pixel_size),int(FOV_size/pixel_size))\n"," #dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. 
Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. The models was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(180, 5, txt = text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if raw_data==False:\n"," simul_text = 'The training dataset was created in the notebook using the following simulation settings:'\n"," pdf.cell(200, 5, txt=simul_text, align='L')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Setting</th><th>Simulated Value</th></tr>
<tr><td>FOV_size</td><td>{0}</td></tr>
<tr><td>pixel_size</td><td>{1}</td></tr>
<tr><td>ADC_per_photon_conversion</td><td>{2}</td></tr>
<tr><td>ReadOutNoise_ADC</td><td>{3}</td></tr>
<tr><td>ADC_offset</td><td>{4}</td></tr>
<tr><td>emitter_density</td><td>{5}</td></tr>
<tr><td>emitter_density_std</td><td>{6}</td></tr>
<tr><td>number_of_frames</td><td>{7}</td></tr>
<tr><td>sigma</td><td>{8}</td></tr>
<tr><td>sigma_std</td><td>{9}</td></tr>
<tr><td>n_photons</td><td>{10}</td></tr>
<tr><td>n_photons_std</td><td>{11}</td></tr>
\n"," \"\"\".format(FOV_size, pixel_size, ADC_per_photon_conversion, ReadOutNoise_ADC, ADC_offset, emitter_density, emitter_density_std, number_of_frames, sigma, sigma_std, n_photons, n_photons_std)\n"," pdf.write_html(html)\n"," else:\n"," simul_text = 'The training dataset was simulated using ThunderSTORM and loaded into the notebook.'\n"," pdf.multi_cell(190, 5, txt=simul_text, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," #pdf.ln(1)\n"," #pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'ImageData_path', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = ImageData_path, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'LocalizationData_path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = LocalizationData_path, align = 'L')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'pixel_size:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = str(pixel_size), align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used to generate patches:')\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(str(patch_size)+'x'+str(patch_size), upsampling_factor, num_patches_per_frame, min_number_of_emitters_per_patch, max_num_patches, gaussian_sigma, Automatic_normalization, L2_weighting_factor)\n"," pdf.write_html(html)\n"," pdf.ln(3)\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n","
<tr><th>Patch Parameter</th><th>Value</th></tr>
<tr><td>patch_size</td><td>{0}</td></tr>
<tr><td>upsampling_factor</td><td>{1}</td></tr>
<tr><td>num_patches_per_frame</td><td>{2}</td></tr>
<tr><td>min_number_of_emitters_per_patch</td><td>{3}</td></tr>
<tr><td>max_num_patches</td><td>{4}</td></tr>
<tr><td>gaussian_sigma</td><td>{5}</td></tr>
<tr><td>Automatic_normalization</td><td>{6}</td></tr>
<tr><td>L2_weighting_factor</td><td>{7}</td></tr>
\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Training Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>batch_size</td><td>{1}</td></tr>
<tr><td>number_of_steps</td><td>{2}</td></tr>
<tr><td>percentage_validation</td><td>{3}</td></tr>
<tr><td>initial_learning_rate</td><td>{4}</td></tr>
\n"," \"\"\".format(number_of_epochs,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," # pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training Images', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_DeepSTORM2D.png').shape\n"," pdf.image('/content/TrainingDataExample_DeepSTORM2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Deep-STORM'\n"," #model_name = os.path.basename(full_QC_model_path)\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+os.path.basename(QC_model_path)+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," if os.path.exists(savePath+'/lossCurvePlots.png'):\n"," exp_size = io.imread(savePath+'/lossCurvePlots.png').shape\n"," pdf.image(savePath+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(savePath+'/QC_example_data.png').shape\n"," pdf.image(savePath+'/QC_example_data.png', x = 16, y = None, w = 
round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(savePath+'/'+os.path.basename(QC_model_path)+'_QC_metrics.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n","\n","\n"," print('------------------------------')\n"," print('QC PDF report exported as '+savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"E04mOlG_H5Tz"},"source":["# **2. Complete the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"F_tjlGzsH-Dn"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"gn-LaaNNICqL","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","import tensorflow as tf\n","# if tf.__version__ != '2.2.0':\n","# !pip install tensorflow==2.2.0\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime settings are correct then Google did not allocate GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tnP7wM79IKW-"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"1R-7Fo34_gOd","cellView":"form"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vu8f5NGJkJos"},"source":["\n","# **3. Generate patches for training**\n","---\n","\n","For Deep-STORM the training data can be obtained in two ways:\n","* Simulated using ThunderSTORM or other simulation tool and loaded here (**using Section 3.1.a**)\n","* Directly simulated in this notebook (**using Section 3.1.b**)\n"]},{"cell_type":"markdown","metadata":{"id":"WSV8xnlynp0l"},"source":["## **3.1.a Load training data**\n","---\n","\n","Here you can load your simulated data along with its corresponding localization file.\n","* The `pixel_size` is defined in nanometer (nm). "]},{"cell_type":"code","metadata":{"id":"CT6SNcfNg6j0","cellView":"form"},"source":["#@markdown ##Load raw data\n","\n","load_raw_data = True\n","\n","# Get user input\n","ImageData_path = \"\" #@param {type:\"string\"}\n","LocalizationData_path = \"\" #@param {type: \"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)\n","\n","# load the tiff data\n","Images = io.imread(ImageData_path)\n","# get dataset dimensions\n","if len(Images.shape) == 3:\n"," (number_of_frames, M, N) = Images.shape\n","elif len(Images.shape) == 2:\n"," (M, N) = Images.shape\n"," number_of_frames = 1\n","print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n","\n","# Interactive display of the stack\n","def scroll_in_time(frame):\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source at frame = ' + str(frame))\n"," plt.axis('off');\n","\n","if number_of_frames > 1:\n"," interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","else:\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images, interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source')\n"," plt.axis('off');\n","\n","# Load the localization file and display the first\n","LocData = pd.read_csv(LocalizationData_path, index_col=0)\n","LocData.tail()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9xE5GeYiks9"},"source":["## **3.1.b Simulate training data**\n","---\n","This simulation tool allows you to generate SMLM data of randomly distrubuted emitters in a field-of-view. \n","The assumptions are as follows:\n","\n","* Gaussian Point Spread Function (PSF) with standard deviation defined by `Sigma`. The nominal value of `sigma` can be evaluated using `sigma = 0.21 x Lambda / NA`. 
(from [Zhang *et al.*, Applied Optics 2007](https://doi.org/10.1364/AO.46.001819))\n","* Each emitter will emit `n_photons` per frame, and generate their equivalent Poisson noise.\n","* The camera will contribute Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC` in ADC\n","* The `emitter_density` is defined as the number of emitters / um^2 on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`. The latter parameter represents the standard deviation of the normal distribution that the density is drawn from for each individual frame. `emitter_density` **is defined in number of emitters / um^2**.\n","* The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.\n","\n","Important note:\n","- All dimensions are in nanometer (e.g. `FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).\n","\n"]},{"cell_type":"code","metadata":{"id":"sQyLXpEhitsg","cellView":"form"},"source":["load_raw_data = False\n","\n","# ---------------------------- User input ----------------------------\n","#@markdown Run the simulation\n","#@markdown --- \n","#@markdown Camera settings: \n","FOV_size = 6400#@param {type:\"number\"}\n","pixel_size = 100#@param {type:\"number\"}\n","ADC_per_photon_conversion = 1 #@param {type:\"number\"}\n","ReadOutNoise_ADC = 4.5#@param {type:\"number\"}\n","ADC_offset = 50#@param {type:\"number\"}\n","\n","#@markdown Acquisition settings: \n","emitter_density = 6#@param {type:\"number\"}\n","emitter_density_std = 0#@param {type:\"number\"}\n","\n","number_of_frames = 20#@param {type:\"integer\"}\n","\n","sigma = 110 #@param {type:\"number\"}\n","sigma_std = 5 #@param {type:\"number\"}\n","# NA = 1.1 #@param {type:\"number\"}\n","# wavelength = 800#@param {type:\"number\"}\n","# wavelength_std = 150#@param {type:\"number\"}\n","n_photons = 2250#@param {type:\"number\"}\n","n_photons_std = 250#@param {type:\"number\"}\n","\n","\n","# ---------------------------- Variable initialisation ----------------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","print('-----------------------------------------------------------')\n","n_molecules = emitter_density*FOV_size*FOV_size/10**6\n","n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6\n","print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))\n","\n","# sigma = 0.21*wavelength/NA\n","# sigma_std = 0.21*wavelength_std/NA\n","# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n","\n","M = N = round(FOV_size/pixel_size)\n","FOV_size = M*pixel_size\n","print('Final image size: '+str(M)+'x'+str(M)+' ('+str(round(FOV_size/1000, 3))+'um x'+str(round(FOV_size/1000,3))+' um)')\n","\n","np.random.seed(1)\n","display_upsampling = 8 # used to display the loc map here\n","NoiseFreeImages = np.zeros((number_of_frames, M, M))\n","locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n","\n","frames = []\n","all_xloc = []\n","all_yloc = []\n","all_photons = []\n","all_sigmas = []\n","\n","# ---------------------------- Main simulation loop ----------------------------\n","print('-----------------------------------------------------------')\n","for f in tqdm(range(number_of_frames)):\n"," \n"," # Define the coordinates of emitters by randomly distributing them across the FOV\n"," n_mol = int(max(round(np.random.normal(n_molecules, 
n_molecules_std, size=1)[0]), 0))\n"," x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)\n"," sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)\n"," # x_c = np.linspace(0,3000,5)\n"," # y_c = np.linspace(0,3000,5)\n","\n"," all_xloc += x_c.tolist()\n"," all_yloc += y_c.tolist()\n"," frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n"," all_photons += photon_array.tolist()\n"," all_sigmas += sigma_array.tolist()\n","\n"," locImage[f] = FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)\n","\n"," # # Get the approximated locations according to the grid pixel size\n"," # Chr_emitters = [int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n"," # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n","\n"," # # Build Localization image\n"," # for (r,c) in zip(Rhr_emitters, Chr_emitters):\n"," # locImage[f][r][c] += 1\n","\n"," NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, image_size = (M,M), pixel_size = pixel_size)\n","\n","\n","# ---------------------------- Create DataFrame fof localization file ----------------------------\n","# Table with localization info as dataframe output\n","LocData = pd.DataFrame()\n","LocData[\"frame\"] = frames\n","LocData[\"x [nm]\"] = all_xloc\n","LocData[\"y [nm]\"] = all_yloc\n","LocData[\"Photon #\"] = all_photons\n","LocData[\"Sigma [nm]\"] = all_sigmas\n","LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n","\n","\n","# ---------------------------- Estimation of SNR ----------------------------\n","n_frames_for_SNR = 100\n","M_SNR = 10\n","x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)\n","sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)\n","\n","SNR = np.zeros(n_frames_for_SNR)\n","for i in range(n_frames_for_SNR):\n"," SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([x_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)\n"," Signal_photon = np.max(SingleEmitterImage)\n"," Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)\n"," SNR[i] = Signal_photon/Noise_photon\n","\n","print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))\n","# ---------------------------- ----------------------------\n","\n","\n","# Table with info\n","simParameters = pd.DataFrame()\n","simParameters[\"FOV size (nm)\"] = [FOV_size]\n","simParameters[\"Pixel size (nm)\"] = [pixel_size]\n","simParameters[\"ADC/photon\"] = [ADC_per_photon_conversion]\n","simParameters[\"Read-out noise (ADC)\"] = [ReadOutNoise_ADC]\n","simParameters[\"Constant offset (ADC)\"] = [ADC_offset]\n","\n","simParameters[\"Emitter density (emitters/um^2)\"] = [emitter_density]\n","simParameters[\"STD of emitter density (emitters/um^2)\"] = [emitter_density_std]\n","simParameters[\"Number of frames\"] = [number_of_frames]\n","# simParameters[\"NA\"] = [NA]\n","# simParameters[\"Wavelength (nm)\"] = [wavelength]\n","# simParameters[\"STD of 
wavelength (nm)\"] = [wavelength_std]\n","simParameters[\"Sigma (nm))\"] = [sigma]\n","simParameters[\"STD of Sigma (nm))\"] = [sigma_std]\n","simParameters[\"Number of photons\"] = [n_photons]\n","simParameters[\"STD of number of photons\"] = [n_photons_std]\n","simParameters[\"SNR\"] = [np.mean(SNR)]\n","simParameters[\"STD of SNR\"] = [np.std(SNR)]\n","\n","\n","# ---------------------------- Finish simulation ----------------------------\n","# Calculating the noisy image\n","Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n","Images[Images <= 0] = 0\n","\n","# Convert to 16-bit or 32-bits integers\n","if Images.max() < (2**16-1):\n"," Images = Images.astype(np.uint16)\n","else:\n"," Images = Images.astype(np.uint32)\n","\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for simulation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","# Interactively display the results using Widgets\n","def scroll_in_time(frame):\n"," f = plt.figure(figsize=(18,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n"," plt.title('Localization image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noise-free simulation')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noisy simulation')\n"," plt.axis('off');\n","\n","interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","\n","# Display the head of the dataframe with localizations\n","LocData.tail()\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pz7RfSuoeJeq","cellView":"form"},"source":["#@markdown ---\n","#@markdown ##Play this cell to save the simulated stack\n","#@markdown Please select a path to the folder where to save the simulated data. It is not necessary to save the data to run the training, but keeping the simulated for your own record can be useful to check its validity.\n","Save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(Save_path):\n"," os.makedirs(Save_path)\n"," print('Folder created.')\n","else:\n"," print('Training data already exists in folder: Data overwritten.')\n","\n","saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)\n","# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)\n","LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))\n","simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))\n","print('Training dataset saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K_8e3kE-JhVY"},"source":["## **3.2. Generate training patches**\n","---\n","\n","Training patches need to be created from the training data generated above. \n","* The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. 
**DEFAULT: 26**\n","* The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). This is used to generate the super-resolved patches as target dataset. Using an `upsampling_factor` of 16 will require the use of more memory and it may be necessary to decreae the `patch_size` to 16 for example. **DEFAULT: 8**\n","* The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**\n","* The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**\n","* The `max_num_patches` defines the maximum number of patches to generate. Fewer may be generated depending on how many pacthes are rejected and how many frames are available. **DEFAULT: 10000**\n","* The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**\n","* The `L2_weighting_factor` is a normalization factor used in the loss function. It helps balancing the loss from the L2 norm. When using higher densities, this factor should be decreased and vice-versa. This factor can be autimatically calculated using an empiraical formula. **DEFAULT: 100**\n","\n"]},{"cell_type":"code","metadata":{"id":"AsNx5KzcFNvC","cellView":"form"},"source":["#@markdown ## **Provide patch parameters**\n","\n","\n","# -------------------- User input --------------------\n","patch_size = 26 #@param {type:\"integer\"}\n","upsampling_factor = 8 #@param [\"4\", \"8\", \"16\"] {type:\"raw\"}\n","num_patches_per_frame = 500#@param {type:\"integer\"}\n","min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n","max_num_patches = 10000#@param {type:\"integer\"}\n","gaussian_sigma = 1#@param {type:\"integer\"}\n","\n","#@markdown Estimate the optimal normalization factor automatically?\n","Automatic_normalization = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, it will use the following value:\n","L2_weighting_factor = 100 #@param {type:\"number\"}\n","\n","\n","# -------------------- Prepare variables --------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# Initialize some parameters\n","pixel_size_hr = pixel_size/upsampling_factor # in nm\n","n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n","patch_size = patch_size*upsampling_factor\n","\n","# Dimensions of the high-res grid\n","Mhr = upsampling_factor*M # in pixels\n","Nhr = upsampling_factor*N # in pixels\n","\n","# Initialize the training patches and labels\n","patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","\n","# Run over all frames and construct the training examples\n","k = 1 # current patch count\n","skip_counter = 0 # number of dataset skipped due to low density\n","id_start = 0 # id position in LocData for current frame\n","print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n","\n","n_locs = len(LocData.index)\n","print('Total number of localizations: '+str(n_locs))\n","density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)\n","print('Density: 
'+str(round(density,2))+' locs/um^2')\n","n_locs_per_patch = patch_size**2*density\n","\n","if Automatic_normalization:\n"," # This empirical formulae attempts to balance the loss L2 function between the background and the bright spikes\n"," # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6^2 0.1um pixel size and density of 3 (hence the 20.28), at upsampling_factor = 8\n"," L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(upsampling_factor**2*20.28))\n"," print('Normalization factor: '+str(round(L2_weighting_factor,2)))\n","\n","# -------------------- Patch generation loop --------------------\n","\n","print('-----------------------------------------------------------')\n","for (f, thisFrame) in enumerate(tqdm(Images)):\n","\n"," # Upsample the frame\n"," upsampledFrame = np.kron(thisFrame, np.ones((upsampling_factor,upsampling_factor)))\n"," # Read all the provided high-resolution locations for current frame\n"," DataFrame = LocData[LocData['frame'] == f+1].copy()\n","\n"," # Get the approximated locations according to the high-res grid pixel size\n"," Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," id_start += len(DataFrame.index)\n","\n"," # Build Localization image\n"," LocImage = np.zeros((Mhr,Nhr))\n"," LocImage[(Rhr_emitters, Chr_emitters)] = 1\n","\n"," # Here, there's a choice between the original Gaussian (classification approach) and using the erf function\n"," HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, gaussian_sigma) \n"," # HeatMapImage = L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])), \n"," # np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)), \n"," # Mhr, pixel_size_hr)\n"," \n","\n"," # Generate random position for the top left corner of the patch\n"," xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n"," yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n","\n"," for c in range(len(xc)):\n"," if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n"," skip_counter += 1\n"," continue\n"," \n"," else:\n"," # Limit maximal number of training examples to 15k\n"," if k > max_num_patches:\n"," break\n"," else:\n"," # Assign the patches to the right part of the images\n"," patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," k += 1 # increment current patch count\n","\n","# Remove the empty data\n","patches = patches[:k-1]\n","spikes = spikes[:k-1]\n","heatmaps = heatmaps[:k-1]\n","n_patches = k-1\n","\n","# -------------------- Failsafe --------------------\n","# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n","if ((k-1) < 5000):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(bcolors.WARNING+'!! WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. 
!!'+bcolors.NORMAL)\n","\n","\n","\n","# -------------------- Displays --------------------\n","print('Number of patches skipped due to low density: '+str(skip_counter))\n","# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n","# print('Size of patches: '+str(dataSize)+' MB')\n","print(str(n_patches)+' patches were generated.')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display patches interactively with a slider\n","def scroll_patches(patch):\n"," f = plt.figure(figsize=(16,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')\n"," plt.title('Raw data (frame #'+str(patch)+')')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(heatmaps[patch-1], interpolation='nearest')\n"," plt.title('Heat map')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(spikes[patch-1], interpolation='nearest')\n"," plt.title('Localization map')\n"," plt.axis('off');\n"," \n"," plt.savefig('/content/TrainingDataExample_DeepSTORM2D.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DSjXFMevK7Iz"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"hVeyKU0MdAPx"},"source":["## **4.1. Select your paths and parameters**\n","\n","---\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see 5). **Default value: 80**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. **If this value is set to 0**, by default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30** \n","\n","**`initial_learning_rate`:** This parameter represents the initial value to be used as learning rate in the optimizer. 
**Default value: 0.001**"]},{"cell_type":"code","metadata":{"id":"oa5cDZ7f_PF6","cellView":"form"},"source":["#@markdown ###Path to training images and parameters\n","\n","model_path = \"\" #@param {type: \"string\"} \n","model_name = \"\" #@param {type: \"string\"} \n","number_of_epochs = 100#@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"integer\"}\n","\n","number_of_steps = 0#@param {type:\"integer\"}\n","percentage_validation = 30 #@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","percentage_validation /= 100\n","if number_of_steps == 0: \n"," number_of_steps = int((1-percentage_validation)*n_patches/batch_size)\n"," print('Number of steps: '+str(number_of_steps))\n","\n","# Pretrained model path initialised here so next cell does not need to be run\n","h5_file_path = ''\n","Use_pretrained_model = False\n","\n","if not ('patches' in locals()):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(WARNING+'!! WARNING: No patches were found in memory currently. !!')\n","\n","Save_path = os.path.join(model_path, model_name)\n","if os.path.exists(Save_path):\n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","print('-----------------------------')\n","print('Training parameters set.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WIyEvQBWLp9n"},"source":["\n","## **4.2. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"oHL5g0w8LqR0","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," 
lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print('No pretrained network will be used.')\n"," h5_file_path = ''\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OADNcie-LHxA"},"source":["## **4.4. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"id":"qDgMu_mAK8US","cellView":"form"},"source":["#@markdown ##Start training\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(Save_path):\n"," shutil.rmtree(Save_path)\n","\n","# Create the model folder!\n","os.makedirs(Save_path)\n","\n","# Export pdf summary \n","pdf_export(raw_data = load_raw_data, pretrained_model = Use_pretrained_model)\n","\n","# Let's go !\n","train_model(patches, heatmaps, Save_path, \n"," steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,\n"," upsampling_factor = upsampling_factor,\n"," validation_split = percentage_validation,\n"," initial_learning_rate = initial_learning_rate, \n"," pretrained_model_path = h5_file_path,\n"," L2_weighting_factor = L2_weighting_factor)\n","\n","# # Show info about the GPU memory useage\n","# !nvidia-smi\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, 
\"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# export pdf after training to update the existing document\n","pdf_export(trained = True, raw_data = load_raw_data, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4N7-ShZpLhwr"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"JDRsm7uKoBa-","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the name of this folder as `QC_model_path` . \n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(QC_model_path):\n"," print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Gw7KaHZUoHC4"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. 
In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qUc-JMOcoGNZ","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," if row:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'), bbox_inches='tight', pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"32eNQjFioQkY"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\" using teh corresponding localization data contained in \"QC_loc_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
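\n","\n","As a rough, illustrative sketch only (the cell below performs the actual computation, after the images have been normalised against each other), these metrics could be obtained with scikit-image and NumPy:\n","\n","```python\n","import numpy as np\n","from skimage.metrics import structural_similarity, peak_signal_noise_ratio\n","\n","# gt and pred: hypothetical normalised ground-truth and prediction images, data range [0, 1]\n","mssim, ssim_map = structural_similarity(gt, pred, data_range=1.0, full=True)\n","rse_map = np.sqrt(np.square(gt - pred))  # root squared error map\n","nrmse = np.sqrt(np.mean(rse_map))  # NRMSE as computed in the cell below\n","psnr_value = peak_signal_noise_ratio(gt, pred, data_range=1.0)\n","```\n","\n","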
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"dhlTnxC5lUZy","cellView":"form"},"source":["\n","# ------------------------ User input ------------------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","QC_image_folder = \"\" #@param{type:\"string\"}\n","QC_loc_folder = \"\" #@param{type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size_INPUT = None\n","else:\n"," pixel_size_INPUT = pixel_size\n","\n","\n","# ------------------------ QC analysis loop over provided dataset ------------------------\n","\n","savePath = os.path.join(QC_model_path, 'Quality Control')\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(os.path.join(savePath, os.path.basename(QC_model_path)+\"_QC_metrics.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. GT PSNR\"])\n","\n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvWF_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvWF_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvWF_list = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n"," for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):\n"," print('--------------')\n"," print(imageFilename)\n"," print(locFilename)\n","\n"," # Get the prediction\n"," batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)\n","\n"," # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n"," thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n"," thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n","\n"," Mhr = thisPrediction.shape[0]\n"," Nhr = thisPrediction.shape[1]\n","\n"," if pixel_size_INPUT == None:\n"," pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))\n","\n"," upsampling_factor = int(Mhr/M)\n"," print('Upsampling factor: '+str(upsampling_factor))\n"," pixel_size_hr = pixel_size/upsampling_factor # in nm\n","\n"," # Load the localization file and display the first\n"," LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n","\n"," x = np.array(list(LocData['x [nm]']))\n"," y = np.array(list(LocData['y [nm]']))\n"," locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)\n","\n"," # Remove extension from filename\n"," imageFilename_no_extension = os.path.splitext(imageFilename)[0]\n","\n"," # io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)\n"," saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, 
normalize_gt=True)\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n"," index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n","\n","\n"," # Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)\n","\n","\n"," img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)\n","\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)\n","\n"," img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsWF = np.sqrt(np.mean(img_RSE_GTvsWF))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n","\n"," writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n","\n"," # Collect values to display in dataframe output\n"," file_name_list.append(imageFilename)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvWF_list.append(PSNR_GTvsWF)\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = file_name_list)\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n","pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n","pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n","pdResults[\"Prediction v. 
GT PSNR\"] = PSNR_GvP_list\n","pdResults[\"Wide-field v. GT PSNR\"] = PSNR_GvWF_list\n","\n","\n","# ------------------------ Display ------------------------\n","\n","print('--------------------------------------------')\n","@interact\n","def show_QC_results(file = list_files(QC_image_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,15))\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n"," plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))\n"," plt.title('Target',fontsize=15)\n","\n"," # Wide-field\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield',fontsize=15)\n","\n"," #Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Prediction',fontsize=15)\n","\n"," #Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n"," imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. GT mSSIM\"],3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. 
GT mSSIM\"],3)),fontsize=14)\n","\n"," #Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n"," imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. GT PSNR\"],3)),fontsize=14)\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. GT PSNR\"],3)),fontsize=14)\n"," plt.savefig(QC_model_path+'/Quality Control/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n","print('--------------------------------------------')\n","pdResults.head()\n","\n","# Export pdf wth summary of QC results\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aqG902y25P4r"},"source":["## **5.4. Export your model into the BioImage Model Zoo format**\n","---\n","This section exports the model into the BioImage Model Zoo format so it can be used directly with DeepImageJ. The new files will be stored in the model folder specified at the beginning of Section 5. \n","\n","Once the cell is executed, you will find a new zip file with the name specified in `Trained_model_name.bioimage.io.model`.\n","\n","To use it with deepImageJ, download it and unzip it in the ImageJ/models/ or Fiji/models/ folder of your local machine. \n","\n","In ImageJ, open the example image given within the downloaded zip file. Go to Plugins > DeepImageJ > DeepImageJ Run. 
Choose this model from the list and click OK.\n","\n"," More information at https://deepimagej.github.io/deepimagej/"]},{"cell_type":"code","metadata":{"cellView":"form","id":"IWqJjg9B5QS4"},"source":["# ------------- User input ------------\n","# information about the model\n","# @markdown ##Introduce the metadata of the model architecture:\n","Trained_model_name = \"\" #@param {type:\"string\"}\n","Trained_model_authors = \"[Author 1, Author 2, Author 3]\" #@param {type:\"string\"}\n","Trained_model_description = \"\" #@param {type:\"string\"}\n","Trained_model_license = 'MIT'#@param {type:\"string\"}\n","Trained_model_references = [\"Nehme E. et al., Optica 2018;\", \"Lucas von Chamier et al., biorXiv 2020\"]\n","Trained_model_DOI = [\"https://doi.org/10.1364/OPTICA.5.000458\", \"https://doi.org/10.1101/2020.03.20.000133\"]\n","\n","# information about the example image\n","#@markdown ##Do you want to choose the example image?\n","default_example_image = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input the path to the file:\n","example_image_file = \"\" #@param {type:\"string\"}\n","#@markdown ##Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown ###Otherwise, use this value (in nm):\n","PixelSize = 100 #@param {type:\"number\"}\n","\n","if default_example_image:\n"," files = [file for file in list_files(QC_image_folder, 'tif')]\n"," example_image_file = os.path.join(QC_image_folder,files[0])\n","\n","if get_pixel_size_from_file:\n"," PixelSize, _, _ = getPixelSizeTIFFmetadata(example_image_file, display=True)\n","\n","# Create one example image (a 2D slice) and its output\n","\n","## Default values and data especific ones.\n","matfile = sio.loadmat(os.path.join(QC_model_path,'model_metadata.mat'))\n","test_mean = matfile['mean_test'].item() # convert to scalar\n","test_std = matfile['std_test'].item() # convert to scalar\n","upsampling_factor = np.array(matfile['upsampling_factor'])\n","upsampling_factor = upsampling_factor.item() # convert to scalar\n","L2_weighting_factor = np.array(matfile['Normalization factor'])\n","L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n","thresh = 0.1\n","pixel_size_hr = PixelSize/upsampling_factor\n","neighborhood_size = 3\n","# max_layer_thresh = thresh*L2_weighting_factor\n","\n","# Read the input image\n","test_img_stack = io.imread(example_image_file)\n","if len(test_img_stack.shape)>2:\n"," test_img = test_img_stack[0]\n","else:\n"," test_img = test_img_stack\n","input_im = project_01(test_img)\n","input_im = normalize_im(input_im, test_mean, test_std)\n","input_im = np.kron(input_im, np.ones((upsampling_factor,upsampling_factor)))\n","\n","# get dataset dimensions\n","(M, N) = input_im.shape\n","\n","# Build the model for a bigger image\n","model = buildModel((M, N, 1))\n","# Load the trained weights\n","model.load_weights(os.path.join(QC_model_path,'weights_best.hdf5'))\n","\n","# Reshaping\n","input_im = np.expand_dims(input_im, axis=[0,-1])\n","input_im = input_im.astype(np.float32)\n","# Inference\n","predicted_density = model.predict(input_im)\n","# Post-processing\n","predicted_density[predicted_density < 0] = 0\n","test_prediction = predicted_density.sum(axis = 3).sum(axis = 0)\n","test_prediction /= L2_weighting_factor\n","\n","# # Reduce model input size if necessary to avoid out of memory errors in Fiji\n","# M = np.min((M, 512))\n","# N = np.min((N, 512))\n","# Build the model for a bigger image\n","model = buildModel((M, N, 1))\n","# Load 
the trained weights\n","model.load_weights(os.path.join(QC_model_path,'weights_best.hdf5'))\n","\n","# Run this cell to export the model to the BioImage Model Zoo format.\n","####\n","from pydeepimagej.yaml import BioImageModelZooConfig\n","# from pydeepimagej.yaml.bioimage_specifications import get_specification\n","import urllib\n","\n","# ------------- Execute bioimage model zoo configuration ------------\n","# Check minimum size: it is [8,8] for the 2D XY plane\n","pooling_steps = 0\n","for keras_layer in model.layers:\n"," if keras_layer.name.startswith('max') or \"pool\" in keras_layer.name:\n"," pooling_steps += 1\n","MinimumSize = [2**(pooling_steps), 2**(pooling_steps)]\n","\n","dij_config = BioImageModelZooConfig(model, MinimumSize)\n","# we avoid padding for SMLM\n","dij_config.Halo = [0, 0]\n","\n","# Model developer details\n","dij_config.Authors = Trained_model_authors[1:-1].split(',')\n","dij_config.Description = Trained_model_description\n","dij_config.Name = Trained_model_name\n","dij_config.References = Trained_model_references\n","dij_config.DOI = Trained_model_DOI\n","dij_config.License = Trained_model_license\n","\n","# Additional information about the model\n","dij_config.GitHub = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic'\n","dij_config.Date = datetime.now()\n","dij_config.Documentation = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki'\n","dij_config.Tags = ['ZeroCostDL4Mic', 'deepimagej', 'SMLM', 'super-resolution', 'image reconstruction']\n","dij_config.Framework = 'tensorflow'\n","\n","# Add the information about the test image. Note here PixelSize is given in nm\n","dij_config.add_test_info(test_img, test_prediction, [0.001*PixelSize, 0.001*PixelSize])\n","dij_config.create_covers([test_img, test_prediction])\n","dij_config.Covers = ['./input.png', './output.png']\n","\n","# Store the model weights\n","# ---------------------------------------\n","# used_bioimageio_model_for_training_URL = \"/Some/URL/bioimage.io/\"\n","# dij_config.Parent = used_bioimageio_model_for_training_URL\n","\n","# Add weights information\n","format_authors = [\"pydeepimagej\"]\n","dij_config.add_weights_formats(model, 'TensorFlow', \n"," parent=\"keras_hdf5\",\n"," authors=[a for a in format_authors])\n","dij_config.add_weights_formats(model, 'KerasHDF5', \n"," authors=[a for a in format_authors])\n","\n","## Prepare preprocessing file\n","path_preprocessing = \"MeanNormalization.ijm\"\n","urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/MeanNormalization.ijm\", path_preprocessing )\n","# Modify the threshold in the macro to the chosen threshold\n","ijmacro = open(path_preprocessing,\"r\") \n","list_of_lines = ijmacro. readlines()\n","# Change model especific parameters\n","list_of_lines[19] = \"paramMean = {};\\n\".format(test_mean)\n","list_of_lines[20] = \"paramStd = {};\\n\".format(test_std)\n","list_of_lines.insert(len(list_of_lines), '\\n')\n","list_of_lines.insert(len(list_of_lines), '// Scaling\\n')\n","list_of_lines.insert(len(list_of_lines), 'run(\"Scale...\", \"x={0} y={0} interpolation=None create title=upsampled_input\");\\n'.format(upsampling_factor, N*upsampling_factor, M*upsampling_factor))\n","ijmacro.close()\n","ijmacro = open(path_preprocessing,\"w\") \n","ijmacro. writelines(list_of_lines)\n","ijmacro. 
close()\n","\n","## Prepare postprocessing file for Maxima Localization\n","path_postprocessing_max = \"LocalMaximaSMLM.ijm\"\n","urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/LocalMaximaSMLM.ijm\", path_postprocessing_max )\n","\n","ijmacro = open(path_postprocessing_max,\"r\") \n","list_of_lines = ijmacro. readlines()\n","list_of_lines[11] = \"thresh = {};\\n\".format(thresh)\n","list_of_lines[12] = \"L2_weighting_factor = {};\\n\".format(L2_weighting_factor)\n","list_of_lines[13] = \"neighborhood_size = {};\\n\".format(neighborhood_size)\n","list_of_lines[14] = \"pixelSize = {}; // in nm and after upsampling\\n\".format(pixel_size_hr)\n","ijmacro.close()\n","ijmacro = open(path_postprocessing_max,\"w\") \n","ijmacro. writelines(list_of_lines)\n","ijmacro. close()\n","\n","## Prepare postprocessing file for Averaged Maxima Localization\n","path_postprocessing_avg = \"AveragedMaximaSMLM.ijm\"\n","urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/AveragedMaximaSMLM.ijm\", path_postprocessing_avg)\n","\n","ijmacro = open(path_postprocessing_avg,\"r\") \n","list_of_lines = ijmacro. readlines()\n","list_of_lines[11] = \"thresh = {};\\n\".format(thresh)\n","list_of_lines[12] = \"L2_weighting_factor = {};\\n\".format(L2_weighting_factor)\n","list_of_lines[13] = \"neighborhood_size = {};\\n\".format(neighborhood_size)\n","list_of_lines[14] = \"pixelSize = {}; // in nm and after upsampling\\n\".format(pixel_size_hr)\n","ijmacro.close()\n","ijmacro = open(path_postprocessing_avg,\"w\") \n","ijmacro. writelines(list_of_lines)\n","ijmacro. close()\n","\n","# Include the info about the macros \n","dij_config.Preprocessing = [path_preprocessing]\n","dij_config.Preprocessing_files = [path_preprocessing]\n","\n","dij_config.Postprocessing = [path_postprocessing_max]\n","dij_config.Postprocessing_files = [path_postprocessing_max]\n","\n","## EXPORT THE MODEL\n","deepimagej_model_path = os.path.join(QC_model_path, Trained_model_name+'.bioimage.io.model')\n","dij_config.export_model(deepimagej_model_path)\n","\n","## Add csv with maxima localization and their confidence: \n","# Maxima localization\n","max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, False)\n","bind, xind, yind, confidence = max_layer(predicted_density)\n","bind = bind.eval(session=tf.compat.v1.Session())\n","xind = xind.eval(session=tf.compat.v1.Session())\n","yind = yind.eval(session=tf.compat.v1.Session())\n","confidence = confidence.eval(session=tf.compat.v1.Session())\n","# normalizing the confidence by the L2_weighting_factor\n","confidence /= L2_weighting_factor \n","\n","# turn indices to nms and append to the results\n","xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n","# frmind = (bind.numpy() + b*batch_size + 1).tolist()\n","# xind = xind.numpy().tolist()\n","# yind = yind.numpy().tolist()\n","# confidence = confidence.numpy().tolist()\n","\n","frmind = (bind + 1).tolist()\n","xind = xind.tolist()\n","yind = yind.tolist()\n","confidence = confidence.tolist()\n","\n","with open(os.path.join(deepimagej_model_path, 'Localizations_resultImage_max.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n"," locs = list(zip(frmind, xind, yind, confidence))\n"," writer.writerows(locs)\n","# Zip the bundled model to download\n","shutil.make_archive(deepimagej_model_path, 'zip', deepimagej_model_path)\n","print(\"Localization csv file 
has been added to {0}.zip.\".format(deepimagej_model_path))\n","\n","## Prepare the macro file to process a stack with the trained model\n","path_macro_for_stacks = \"DeepSTORM4stacksThunderSTORM.ijm\"\n","urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/DeepSTORM4stacksThunderSTORM.ijm\", path_macro_for_stacks)\n","shutil.copy2(path_macro_for_stacks, os.path.join(deepimagej_model_path, path_macro_for_stacks))\n","\n","# Save the Averaged Maxima Localization\n","shutil.copy2(path_postprocessing_avg, os.path.join(deepimagej_model_path, path_postprocessing_avg))\n","\n","print(\"An ImageJ macro file to process a entire stack has been added to the folder {0} under the name {1}.\".format(deepimagej_model_path, path_macro_for_stacks))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yTRou0izLjhd"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"eAf8aBDmWTx7"},"source":["## **6.1 Generate image prediction and localizations from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the found localizations csv.\n","\n","**`batch_size`:** This paramter determines how many frames are processed by any single pass on the GPU. A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 4**\n","\n","**`threshold`:** This paramter determines threshold for local maxima finding. The value is expected to reside in the range **[0,1]**. A higher `threshold` will result in less localizations. **DEFAULT: 0.1**\n","\n","**`neighborhood_size`:** This paramter determines size of the neighborhood within which the prediction needs to be a local maxima in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**\n","\n","**`use_local_average`:** This paramter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True** it will make inference slightly slower depending on the size of the FOV. 
**DEFAULT: True**\n"]},{"cell_type":"code","metadata":{"id":"7qn06T_A0lxf","cellView":"form"},"source":["\n","# ------------------------------- User input -------------------------------\n","#@markdown ### Data parameters\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value (in nm):\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","#@markdown ### Model parameters\n","#@markdown Do you want to use the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, please provide path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Prediction parameters\n","batch_size = 4#@param {type:\"integer\"}\n","\n","#@markdown ### Post processing parameters\n","threshold = 0.1#@param {type:\"number\"}\n","neighborhood_size = 3#@param {type:\"integer\"}\n","#@markdown Do you want to locally average the model output with CoG estimator ?\n","use_local_average = True #@param {type:\"boolean\"}\n","\n","\n","if get_pixel_size_from_file:\n"," pixel_size = None\n","\n","if (Use_the_current_trained_model): \n"," prediction_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(prediction_model_path):\n"," print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","# inform user whether local averaging is being used\n","if use_local_average == True: \n"," print('Using local averaging')\n","\n","if not os.path.exists(Result_folder):\n"," print('Result folder was created.')\n"," os.makedirs(Result_folder)\n","\n","\n","# ------------------------------- Run predictions -------------------------------\n","\n","start = time.time()\n","#%% This script tests the trained fully convolutional network based on the \n","# saved training weights, and normalization created using train_model.\n","\n","if os.path.isdir(Data_folder): \n"," for filename in list_files(Data_folder, 'tif'):\n"," # run the testing/reconstruction process\n"," print(\"------------------------------------\")\n"," print(\"Running prediction on: \"+ filename)\n"," batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average,\n"," pixel_size = pixel_size)\n","\n","elif os.path.isfile(Data_folder):\n"," batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average, \n"," pixel_size = pixel_size)\n","\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------------------- Interactive display 
-------------------------------\n","\n","print('--------------------------------------------------------------------')\n","print('---------------------------- Previews ------------------------------')\n","print('--------------------------------------------------------------------')\n","\n","if os.path.isdir(Data_folder): \n"," @interact\n"," def show_QC_results(file = list_files(Data_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n","if os.path.isfile(Data_folder):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZekzexaPmzFZ"},"source":["## **6.2 Drift correction**\n","---\n","\n","The visualization above is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. The display is a preview without any drift correction applied. This section performs drift correction using cross-correlation between time bins to estimate the drift.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the image reconstructions used for the Drift Correction estmication (in **nm**). A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**\n","\n","**`number_of_bins`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bins are used ot build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**\n","\n","**`polynomial_fit_degree`:** The drift obtained for each temporal bins needs to be interpolated to every single frames. This is performed by polynomial fit, the degree of which is defined here. **DEFAULT: 4**\n","\n"," The drift-corrected localization data is automaticaly saved in the `save_path` folder."]},{"cell_type":"code","metadata":{"id":"hYtP_vh6mzUP","cellView":"form"},"source":["# @markdown ##Data parameters\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. 
Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","# @markdown ##Drift correction parameters\n","visualization_pixel_size = 20#@param {type:\"number\"}\n","number_of_bins = 50#@param {type:\"integer\"}\n","polynomial_fit_degree = 4#@param {type:\"integer\"}\n","\n","# @markdown ##Saving parameters\n","save_path = '' #@param {type:\"string\"}\n","\n","\n","# Let's go !\n","start = time.time()\n","\n","# Get info from the raw file if selected\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","# Read the localizations in\n","LocData = pd.read_csv(Loc_file_path)\n","\n","# Calculate a few variables \n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","n_locs = len(LocData.index)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(n_locs))\n","\n","blocksize = math.ceil(nFrames/number_of_bins)\n","print('Number of frames per block: '+str(blocksize))\n","\n","blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()\n","xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n","# Preparing the Reference image\n","photon_array = np.ones(yc_array.shape[0])\n","sigma_array = np.ones(yc_array.shape[0])\n","ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImagesRef = np.rot90(ImageRef, k=2)\n","\n","xDrift = np.zeros(number_of_bins)\n","yDrift = np.zeros(number_of_bins)\n","\n","filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","with open(os.path.join(save_path, filename_no_extension+\"_DriftCorrectionData.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"Block #\", \"x-drift [nm]\",\"y-drift [nm]\"])\n","\n"," for b in tqdm(range(number_of_bins)):\n","\n"," blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()\n"," xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n"," yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n"," photon_array = np.ones(yc_array.shape[0])\n"," sigma_array = np.ones(yc_array.shape[0])\n"," ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n"," XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')\n"," yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')\n","\n"," # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, 
visualization_pixel_size)\n"," # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)\n"," writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","print('Fitting drift data...')\n","bin_number = np.arange(number_of_bins)*blocksize + blocksize/2\n","xDrift = (xDrift-xDrift[0])*visualization_pixel_size\n","yDrift = (yDrift-yDrift[0])*visualization_pixel_size\n","\n","xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)\n","yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)\n","\n","xDriftFit = np.poly1d(xDriftCoeff)\n","yDriftFit = np.poly1d(yDriftCoeff)\n","bins = np.arange(nFrames)\n","xDriftInterpolated = xDriftFit(bins)\n","yDriftInterpolated = yDriftFit(bins)\n","\n","\n","# ------------------ Displaying the image results ------------------\n","\n","plt.figure(figsize=(15,10))\n","plt.plot(bin_number,xDrift, 'r+', label='x-drift')\n","plt.plot(bin_number,yDrift, 'b+', label='y-drift')\n","plt.plot(bins,xDriftInterpolated, 'r-', label='y-drift (fit)')\n","plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')\n","plt.title('Cross-correlation estimated drift')\n","plt.ylabel('Drift [nm]')\n","plt.xlabel('Bin number')\n","plt.legend();\n","\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\", hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------ Actual drift correction -------------------\n","\n","print('Correcting localization data...')\n","xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)\n","frames = LocData['frame'].to_numpy(dtype=np.int32)\n","\n","\n","xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)\n","ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","\n","# ------------------ Displaying the imge results ------------------\n","plt.figure(figsize=(15,7.5))\n","# Raw\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))\n","plt.title('Raw', fontsize=15);\n","# Corrected\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))\n","plt.title('Corrected',fontsize=15);\n","\n","\n","# ------------------ Table with info -------------------\n","driftCorrectedLocData = pd.DataFrame()\n","driftCorrectedLocData['frame'] = frames\n","driftCorrectedLocData['x [nm]'] = xc_array_Corr\n","driftCorrectedLocData['y [nm]'] = yc_array_Corr\n","driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']\n","\n","driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))\n","print('-------------------------------')\n","print('Corrected localizations 
saved.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mzOuc-V7rB-r"},"source":["## **6.3 Visualization of the localizations**\n","---\n","\n","\n","The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). **DEFAULT: 10**\n","\n","**`visualization_mode`:** This parameter defines what visualization method is used to visualize the final image. NOTES: The Integrated Gaussian can be quite slow. **DEFAULT: Simple histogram.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"876yIXnqq-nW","cellView":"form"},"source":["# @markdown ##Data parameters\n","Use_current_drift_corrected_localizations = True #@param {type:\"boolean\"}\n","# @markdown Otherwise provide a localization file path\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100#@param {type:\"number\"}\n","\n","# @markdown ##Visualization parameters\n","visualization_pixel_size = 10#@param {type:\"number\"}\n","visualization_mode = \"Simple histogram\" #@param [\"Simple histogram\", \"Integrated Gaussian (SLOW!)\"]\n","\n","if not Use_current_drift_corrected_localizations:\n"," filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","if Use_current_drift_corrected_localizations:\n"," LocData = driftCorrectedLocData\n","else:\n"," LocData = pd.read_csv(Loc_file_path)\n","\n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","\n","\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(len(LocData.index)))\n","\n","xc_array = LocData['x [nm]'].to_numpy()\n","yc_array = LocData['y [nm]'].to_numpy()\n","if (visualization_mode == 'Simple histogram'):\n"," locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","elif (visualization_mode == 'Shifted histogram'):\n"," print(bcolors.WARNING+'Method not 
implemented yet!'+bcolors.NORMAL)\n"," locImage = np.zeros(image_size)\n","elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):\n"," photon_array = np.ones(xc_array.shape)\n"," sigma_array = np.ones(xc_array.shape)\n"," locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display\n","plt.figure(figsize=(20,10))\n","plt.axis('off')\n","# plt.imshow(locImage, cmap='gray');\n","plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));\n","\n","\n","LocData.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"PdOhWwMn1zIT","cellView":"form"},"source":["# @markdown ---\n","\n","# @markdown #Play this cell to save the visualization\n","\n","# @markdown ####Please select a path to the folder where to save the visualization.\n","save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(save_path):\n"," os.makedirs(save_path)\n"," print('Folder created.')\n","\n","saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)\n","print('Image saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1EszIF4Dkz_n"},"source":["## **6.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"0BvykD0YIk89"},"source":["# **7. Version log**\n","\n","---\n","**v1.13**: \n","* The section 1 and 2 are now swapped for better export of *requirements.txt*. 
\n","* This version also now includes built-in version check and the version log that you're reading now.\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"UgN-NooKk3nV"},"source":["\n","#**Thank you for using Deep-STORM 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/DenoiSeg_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/DenoiSeg_ZeroCostDL4Mic.ipynb index aca3cb87..fdadbedc 100644 --- a/Colab_notebooks/Beta notebooks/DenoiSeg_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Beta notebooks/DenoiSeg_ZeroCostDL4Mic.ipynb @@ -1,2166 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "DenoiSeg_ZeroCostDL4Mic.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.7" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "IkSguVy8Xv83" - }, - "source": [ - "# **Image denoising and segmentation using DenoiSeg 2D**\n", - "\n", - "---\n", - "\n", - " DenoiSeg 2D is deep-learning method that can be used to jointly denoise and segment 2D microscopy images. By running this notebook, you can train your and use you own network. \n", - "\n", - " The benefits of using DenoiSeg (compared to other Deep Learning-based segmentation methods) are more prononced when only a few annotated images are available. However, the denoising part requires many images to perform well. All the noisy images don't need to be labeled to train DenoiSeg.\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is largely based on the paper: **DenoiSeg: Joint Denoising and Segmentation**\n", - "Tim-Oliver Buchholz, Mangal Prakash, Alexander Krull, Florian Jug\n", - "https://arxiv.org/abs/2005.02987\n", - "\n", - "And source code found in: /~https://github.com/juglab/DenoiSeg/wiki\n", - "\n", - "\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use our notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. 
You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gKDLkLWUd-YX" - }, - "source": [ - "# **0. Before getting started**\n", - "---\n", - "\n", - "Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n", - "\n", - "**it needs to have access to a paired training dataset made of images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n", - "\n", - "**Importantly, the benefits of using DenoiSeg are more pronounced when only limited numbers of segmentation annotations are available for training. However, DenoiSeg also expects that lots of noisy raw images are available to train the denoising part. It is therefore not required for all the noisy images to be annotated to train DenoiSeg**.\n", - "\n", - "**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n", - "\n", - "The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split.\n", - "\n", - "Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n", - "\n", - "Please note that you currently can **only use .tif files!**\n", - "\n", - "You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. 
This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n", - "\n", - "Here's a common data structure that can work:\n", - "* Experiment A\n", - " - **Training dataset**\n", - " - Noisy Images (Training_source)\n", - " - img_1.tif, img_2.tif, img_3.tif, img_4.tif, ...\n", - " - Masks (Training_target)\n", - " - img_1.tif, img_2.tif\n", - " - **Quality control dataset (optional, not required for training)**\n", - " - Noisy Images\n", - " - img_1.tif, img_2.tif\n", - " - High SNR Images\n", - " - img_1.tif, img_2.tif\n", - " - Masks \n", - " - img_1.tif, img_2.tif\n", - " - **Data to be predicted**\n", - " - **Results**\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "\n", - "\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zCvebubeSaGY", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "%tensorflow_version 1.x\n", - "\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sNIVx8_CLolt" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. 
This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "01Djr8v-5pPk", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AdN8B91xZO0x" - }, - "source": [ - "# **2. Install DenoiSeg and Dependencies**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nwYTuUfOrtPj" - }, - "source": [ - "## **2.1. Install key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "7TIzloUDrwOA" - }, - "source": [ - "\n", - "#@markdown ##Install denoiseg and dependencies\n", - "\n", - "%tensorflow_version 1.x\n", - "import tensorflow \n", - "\n", - "!pip uninstall -y keras-nightly\n", - "\n", - "!pip3 install h5py==2.10.0\n", - "\n", - "!pip install denoiseg\n", - "!pip install wget\n", - "!pip install fpdf\n", - "!pip install memory_profiler\n", - "\n", - "!pip install q keras==2.2.5\n", - "\n", - "#Force session restart\n", - "exit(0)\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XY4ZDsAMr3Ef" - }, - "source": [ - "## **2.2. Restart your runtime**\n", - "---\n", - "\n", - "\n", - "\n", - "** Your Runtime has automatically restarted. This is normal.**\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "i2d3Z_Ggr7Ng" - }, - "source": [ - "## **2.3. 
Load key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "fq21zJVFNASx", - "cellView": "form" - }, - "source": [ - "#@markdown ##Load key dependencies\n", - "\n", - "\n", - "Notebook_version = ['1.12']\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory\n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "# Here we enable Tensorflow 1. 
\n", - "%tensorflow_version 1.x\n", - "import tensorflow\n", - "print(tensorflow.__version__)\n", - "\n", - "print(\"Tensorflow enabled.\")\n", - "\n", - "%load_ext memory_profiler\n", - "\n", - "\n", - "# Here we install all libraries and other depencies to run the notebook.\n", - "\n", - "# ------- Variable specific to Denoiseg -------\n", - "\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - "from scipy import ndimage\n", - "\n", - "from denoiseg.models import DenoiSeg, DenoiSegConfig\n", - "from denoiseg.utils.misc_utils import combine_train_test_data, shuffle_train_data, augment_data\n", - "from denoiseg.utils.seg_utils import *\n", - "from denoiseg.utils.compute_precision_threshold import measure_precision, compute_labels\n", - "\n", - "from csbdeep.utils import plot_history\n", - "from tifffile import imread, imsave\n", - "from glob import glob\n", - "\n", - "import urllib\n", - "import os\n", - "import zipfile\n", - "\n", - "# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - "import urllib\n", - "import os, random\n", - "import shutil \n", - "import zipfile\n", - "from tifffile import imread, imsave\n", - "import time\n", - "import sys\n", - "import wget\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import csv\n", - "from glob import glob\n", - "from scipy import signal\n", - "from scipy import ndimage\n", - "from skimage import io\n", - "from sklearn.linear_model import LinearRegression\n", - "from skimage.util import img_as_uint\n", - "import matplotlib as mpl\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from astropy.visualization import simple_norm\n", - "from skimage import img_as_float32\n", - "\n", - "from fpdf import FPDF, HTMLMixin\n", - "from datetime import datetime\n", - "import subprocess\n", - "from pip._internal.operations.freeze import freeze\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "\n", - "#Disable some of the tensorflow warnings\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "# For sliders and dropdown menu and progress bar\n", - "from ipywidgets import interact\n", - "import ipywidgets as widgets\n", - "\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "if Notebook_version == list(Latest_notebook_version.columns):\n", - " print(\"This notebook is up-to-date.\")\n", - "\n", - "if not Notebook_version == list(Latest_notebook_version.columns):\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'DenoiSeg 2D'\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate and Time: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and method:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras','csbdeep']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', with a batch size of '+str(batch_size)+' and using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The GPU used was a '+gpu_name+'.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'Data augmentation was enabled'\n", - "\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>batch_size</td><td>{1}</td></tr>
<tr><td>number_of_steps</td><td>{2}</td></tr>
<tr><td>percentage_validation</td><td>{3}</td></tr>
<tr><td>Priority</td><td>{4}</td></tr>
<tr><td>initial_learning_rate</td><td>{5}</td></tr>
\n", - " \"\"\".format(number_of_epochs,batch_size,number_of_steps,percentage_validation,Priority,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_DenoiSeg.png').shape\n", - " pdf.image('/content/TrainingDataExample_DenoiSeg.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- DenoiSeg: Buchholz, Prakash, et al. 
\"DenoiSeg: Joint Denoising and Segmentation\", arXiv 2020.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - "\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'DenoiSeg_2D'\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_segmentation.png').shape\n", - " if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n", - " pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n", - "\n", - " if Evaluate_Segmentation:\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_segmentation.png').shape\n", - " pdf.image(full_QC_model_path+'/Quality Control/QC_example_segmentation.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+\"/Quality Control/QC_metrics_Segmentation_\"+QC_model_name+\".csv\", 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " IoU = header[1] \n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \"\"\".format(image,IoU)\n", - " html = html+header\n", - " i=0\n", - " for row in metrics:\n", - " i+=1\n", - " image = row[0]\n", - " IoU = row[1] \n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(IoU),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
{0}{1}
{0}{1}
\"\"\"\n", - " \n", - " pdf.write_html(html)\n", - "\n", - " if Evaluate_Denoising:\n", - "\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_denoising.png').shape\n", - " pdf.image(full_QC_model_path+'/Quality Control/QC_example_denoising.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+\"/Quality Control/QC_metrics_Denoising_\"+QC_model_name+\".csv\", 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " mSSIM_PvsGT = header[1]\n", - " mSSIM_SvsGT = header[2]\n", - " NRMSE_PvsGT = header[3]\n", - " NRMSE_SvsGT = header[4]\n", - " PSNR_PvsGT = header[5]\n", - " PSNR_SvsGT = header[6]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n", - " html = html+header\n", - " for row in metrics:\n", - " image = row[0]\n", - " mSSIM_PvsGT = row[1]\n", - " mSSIM_SvsGT = row[2]\n", - " NRMSE_PvsGT = row[3]\n", - " NRMSE_SvsGT = row[4]\n", - " PSNR_PvsGT = row[5]\n", - " PSNR_SvsGT = row[6]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n", - " \n", - " pdf.write_html(html)\n", - "\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - " print('------------------------------')\n", - " print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HLYcZR9gMv42" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FQ_QxtSWQ7CL" - }, - "source": [ - "## **3.1. Setting main training parameters**\n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AuESFimvMv43" - }, - "source": [ - " **Paths for training, predictions and results**\n", - "\n", - "**`Training_source:`** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "\n", - "**Training Parameters**\n", - "\n", - "**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 30**\n", - " \n", - "**Advanced Parameters - experienced users only**\n", - "\n", - "**`Priority`:** Choose how much relative the importance to assign to the denoising \n", - "and segmentation tasks by choosing an appropriate value (between 0 and 1; with 0 being only segmentation and 1 being only denoising. **Default value: 0.5**\n", - "\n", - "**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. 
**Default value: depends on number of patches, min 100; max 400**\n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n", - "\n", - "**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ewpNJ_I0Mv47", - "cellView": "form" - }, - "source": [ - "# create DataGenerator-object.\n", - "\n", - "\n", - "#@markdown ###Path to training image(s): \n", - "Training_source = \"\" #@param {type:\"string\"}\n", - "Training_target = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "#@markdown ### Model name and path:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ###Training Parameters\n", - "#@markdown Number of epochs:\n", - "number_of_epochs = 30#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please input:\n", - "Priority = 0.5#@param {type:\"number\"}\n", - "number_of_steps = 100#@param {type:\"number\"}\n", - "batch_size = 128#@param {type:\"number\"}\n", - "percentage_validation = 10#@param {type:\"number\"}\n", - "initial_learning_rate = 0.0004 #@param {type:\"number\"}\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " # number_of_steps is defined in the following cell in this case\n", - " Priority = 0.5\n", - " batch_size = 128\n", - " percentage_validation = 10\n", - " initial_learning_rate = 0.0004\n", - " \n", - "#here we check that no model with the same name already exist, if so print a warning\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n", - " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n", - "\n", - "# This will open a randomly chosen dataset input image\n", - "random_choice = random.choice(os.listdir(Training_target))\n", - "x = imread(Training_source+\"/\"+random_choice)\n", - "\n", - "# Here we disable pre-trained model by default (in case the next cell is not ran)\n", - "Use_pretrained_model = False\n", - "\n", - "# Here we disable data augmentation by default (in case the cell is not ran)\n", - "Use_Data_augmentation = True\n", - "\n", - "# Here we count the number of files in the training target folder\n", - "Mask_Filelist = os.listdir(Training_target)\n", - "Mask_number_files = len(Mask_Filelist)\n", - "\n", - "# Here we count the number of file to use for validation\n", - "Mask_for_validation = int((Mask_number_files)/percentage_validation)\n", - "\n", - "if Mask_for_validation == 0:\n", - " Mask_for_validation = 2\n", - "if Mask_for_validation == 1:\n", - " Mask_for_validation = 2\n", - "\n", - "# Here we count the number of files in the training target folder\n", - "Noisy_Filelist = os.listdir(Training_source)\n", - "Noisy_number_files = len(Noisy_Filelist)\n", - "\n", - "# Here we count the number of file to use for validation\n", - "Noisy_for_validation = int((Noisy_number_files)/percentage_validation)\n", - "\n", - "if Noisy_for_validation == 0:\n", - " Noisy_for_validation = 1\n", - "\n", - "#Here we find the noisy images that do not have masks\n", - "noisy_image_no_mask_list = list(set(Noisy_Filelist) - set(Mask_Filelist))\n", - "\n", - "\n", - "#Here we split the training dataset between training and validation\n", - "# Everything is copied in the /Content Folder\n", - "Training_source_temp = \"/content/training_source\"\n", - "\n", - "if os.path.exists(Training_source_temp):\n", - " shutil.rmtree(Training_source_temp)\n", - "os.makedirs(Training_source_temp)\n", - "\n", - "Training_target_temp = \"/content/training_target\"\n", - "if os.path.exists(Training_target_temp):\n", - " shutil.rmtree(Training_target_temp)\n", - "os.makedirs(Training_target_temp)\n", - "\n", - "Validation_source_temp = \"/content/validation_source\"\n", - "\n", - "if os.path.exists(Validation_source_temp):\n", - " shutil.rmtree(Validation_source_temp)\n", - "os.makedirs(Validation_source_temp)\n", - "\n", - "Validation_target_temp = \"/content/validation_target\"\n", - "if os.path.exists(Validation_target_temp):\n", - " shutil.rmtree(Validation_target_temp)\n", - "os.makedirs(Validation_target_temp)\n", - "\n", - "list_source = os.listdir(os.path.join(Training_source))\n", - "list_target = os.listdir(os.path.join(Training_target))\n", - "\n", - "#Move files into the temporary source and target directories:\n", - "\n", - "for f in os.listdir(os.path.join(Training_source)):\n", - " shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n", - "\n", - "for p in os.listdir(os.path.join(Training_target)):\n", - " shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n", - "\n", - "#Here we move images to be used for validation\n", - "for i in range(Mask_for_validation): \n", - " shutil.move(Training_source_temp+\"/\"+list_target[i], Validation_source_temp+\"/\"+list_target[i])\n", - " shutil.move(Training_target_temp+\"/\"+list_target[i], Validation_target_temp+\"/\"+list_target[i])\n", - "\n", - "#Here we move a few more noisy 
images for validation\n", - "if noisy_image_no_mask_list:\n", - " for y in range(Noisy_for_validation): \n", - " shutil.move(Training_source_temp+\"/\"+noisy_image_no_mask_list[y], Validation_source_temp+\"/\"+noisy_image_no_mask_list[y])\n", - "\n", - "\n", - "print(\"Parameters initiated.\")\n", - "\n", - "y = imread(Training_target+\"/\"+random_choice)\n", - "\n", - "#Here we display one image\n", - "norm = simple_norm(x, percent = 99)\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n", - "plt.title('Training source')\n", - "plt.axis('off');\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y, interpolation='nearest', vmin=0, vmax=1, cmap='viridis')\n", - "plt.title('Training target')\n", - "plt.axis('off');\n", - "plt.savefig('/content/TrainingDataExample_DenoiSeg.png',bbox_inches='tight',pad_inches=0)\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xyQZKby8yFME" - }, - "source": [ - "## **3.2. Data augmentation**\n", - "---\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_jCy7xOx2g3" - }, - "source": [ - "Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n", - "\n", - "Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis (multiply the dataset by 8). \n", - "\n", - " **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "DMqWq5-AxnFU", - "cellView": "form" - }, - "source": [ - "#Data augmentation\n", - "\n", - "Use_Data_augmentation = False #@param {type:\"boolean\"}\n", - "\n", - "if Use_Data_augmentation:\n", - " print(\"Data augmentation enabled\")\n", - "\n", - "if not Use_Data_augmentation:\n", - " print(\"Data augmentation disabled\")\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3L9zSGtORKYI" - }, - "source": [ - "\n", - "## **3.3. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a DenoiSeg model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
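For illustration, here is a minimal sketch (not part of the notebook's own code) of the convention described above, assuming a model folder produced by a previous ZeroCostDL4Mic training run. The file names below (weights_best.h5 and Quality Control/training_evaluation.csv) are the ones this notebook writes; the folder path itself is hypothetical and should be replaced by your own Drive path.

```python
import os
import pandas as pd

# Hypothetical path to a model folder saved by a previous ZeroCostDL4Mic session
pretrained_model_path = "/content/gdrive/MyDrive/models/my_previous_DenoiSeg_model"

weights_file = os.path.join(pretrained_model_path, "weights_best.h5")
csv_path = os.path.join(pretrained_model_path, "Quality Control", "training_evaluation.csv")

if os.path.exists(csv_path):
    # Pick the learning rate logged at the epoch with the lowest validation loss
    history = pd.read_csv(csv_path)
    learning_rate = history.loc[history["val_loss"].idxmin(), "learning rate"]
else:
    # No log found (e.g. a model trained outside ZeroCostDL4Mic): fall back to the default
    learning_rate = 0.0004
```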
" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "9vC2n-HeLdiJ", - "cellView": "form" - }, - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "\n", - "pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n", - "\n", - "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n", - "\n", - "\n", - "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - "# --------------------- Load the model from the choosen path ------------------------\n", - " if pretrained_model_choice == \"Model_from_file\":\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "\n", - "# --------------------- Download the a model provided in the XXX ------------------------\n", - "\n", - " if pretrained_model_choice == \"Model_name\":\n", - " pretrained_model_name = \"Model_name\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path) \n", - " wget.download(\"\", pretrained_model_path)\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "# --------------------- Add additional pre-trained models here ------------------------\n", - "\n", - "\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n", - " if not os.path.exists(h5_file_path):\n", - " print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n", - " Use_pretrained_model = False\n", - "\n", - " \n", - "# If the model path contains a pretrain model, we load the training rate, \n", - " if os.path.exists(h5_file_path):\n", - "#Here we check if the learning rate can be loaded from the quality control folder\n", - " if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - "\n", - " with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " \n", - " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"pretrained network learning rate found\")\n", - " #find the last learning rate\n", - " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n", - " #Find the learning rate corresponding to the lowest validation loss\n", - " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n", - " #print(min_val_loss)\n", - " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n", - "\n", - " if Weights_choice == \"last\":\n", - " print('Last learning rate: '+str(lastLearningRate))\n", - "\n", - " if Weights_choice == \"best\":\n", - " 
print('Learning rate of best validation loss: '+str(bestLearningRate))\n", - "\n", - " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n", - "\n", - "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n", - " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - "\n", - "\n", - "# Display info about the pretrained model to be loaded (or not)\n", - "if Use_pretrained_model:\n", - " print('Weights found in:')\n", - " print(h5_file_path)\n", - " print('will be loaded prior to training.')\n", - "\n", - "else:\n", - " print(bcolors.WARNING+'No pretrained nerwork will be used.')\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YOp8HwavpoON" - }, - "source": [ - "# **4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CKSKY4icpcKb" - }, - "source": [ - "## **4.1. Prepare the training data and model for training**\n", - "---\n", - "Here, we use the information from 3. to build the model and convert the training data into a suitable format for training." - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "0LM_L-5Spb2z" - }, - "source": [ - "#@markdown ##Create the model and dataset objects\n", - "\n", - "# --------------------- Here we delete the model folder if it already exist ------------------------\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "\n", - "\n", - "# --------------------- Here we load the augmented data or the raw data ------------------------\n", - "\n", - "print(\"In progress...\")\n", - "\n", - "Training_source_dir = Training_source_temp\n", - "Training_target_dir = Training_target_temp\n", - "# --------------------- ------------------------------------------------\n", - "\n", - "training_images_tiff=Training_source_dir+\"/*.tif\"\n", - "mask_images_tiff=Training_target_dir+\"/*.tif\"\n", - "\n", - "validation_images_tiff=Validation_source_temp+\"/*.tif\"\n", - "validation_mask_tiff=Validation_target_temp+\"/*.tif\"\n", - "\n", - "train_images = imread(sorted(glob(training_images_tiff)))\n", - "val_images = imread(sorted(glob(validation_images_tiff)))\n", - "\n", - "available_train_masks = imread(sorted(glob(mask_images_tiff)))\n", - "available_val_masks = imread(sorted(glob(validation_mask_tiff)))\n", - "\n", - "#This allows the users to not have all their training images segmented\n", - "blank_images_train = np.zeros((train_images.shape[0]-available_train_masks.shape[0], available_train_masks.shape[1], available_train_masks.shape[2]))\n", - "blank_images_val = np.zeros((val_images.shape[0]-available_val_masks.shape[0], available_val_masks.shape[1], available_val_masks.shape[2]))\n", - "blank_images_train = blank_images_train.astype(\"uint16\")\n", - "blank_images_val = blank_images_val.astype(\"uint16\")\n", - "\n", - "train_masks = np.concatenate((available_train_masks,blank_images_train), axis = 0)\n", - "val_masks = np.concatenate((available_val_masks,blank_images_val), axis = 0)\n", - "\n", - "\n", - "if not Use_Data_augmentation:\n", - " X, Y_train_masks = train_images, train_masks\n", - "\n", - "# Now we apply data augmentation to the training patches:\n", - "# Rotate four times by 90 degree and add flipped versions.\n", - "if Use_Data_augmentation:\n", - " X, Y_train_masks = augment_data(train_images, train_masks)\n", - "\n", - "X_val, Y_val_masks = val_images, val_masks\n", - "\n", - "# Here we add the channel dimension to our input images.\n", - "# Dimensionality for training has to be 'SYXC' (Sample, Y-Dimension, X-Dimension, Channel)\n", - "X = X[...,np.newaxis]\n", - "Y = convert_to_oneHot(Y_train_masks)\n", - "X_val = X_val[...,np.newaxis]\n", - "Y_val = convert_to_oneHot(Y_val_masks)\n", - "print(\"Shape of X: {}\".format(X.shape))\n", - "print(\"Shape of Y: {}\".format(Y.shape))\n", - "print(\"Shape of X_val: {}\".format(X_val.shape))\n", - "print(\"Shape of Y_val: {}\".format(Y_val.shape))\n", - "\n", - "#Here we automatically define number_of_step in function of training data and batch size\n", - "#Here we ensure that our network has a minimal number of steps\n", - "if (Use_Default_Advanced_Parameters): \n", - " number_of_steps= max(100, min(int(X.shape[0]/batch_size), 400))\n", - "\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "#Here we ensure that the learning rate set correctly when using pre-trained models\n", - "if Use_pretrained_model:\n", - " if Weights_choice == \"last\":\n", - " initial_learning_rate = lastLearningRate\n", - "\n", - " if Weights_choice == \"best\": \n", - " initial_learning_rate = bestLearningRate\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "# create a Config object\n", - "\n", - "config = DenoiSegConfig(X, unet_kern_size=3, n_channel_out=4, 
relative_weights = [1.0,1.0,5.0],\n", - " train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n", - " batch_norm=True, train_batch_size=batch_size, unet_n_first = 32, \n", - " unet_n_depth=4, denoiseg_alpha=Priority, train_learning_rate = initial_learning_rate, train_tensorboard=False)\n", - "\n", - "\n", - "# Let's look at the parameters stored in the config-object.\n", - "vars(config)\n", - " \n", - " \n", - "# create network model.\n", - "\n", - "model = DenoiSeg(config=config, name=model_name, basedir=model_path)\n", - "\n", - "\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "# Load the pretrained weights \n", - "if Use_pretrained_model:\n", - " model.load_weights(h5_file_path)\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "#Export summary of training parameters as pdf\n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n", - "\n", - "print(\"Setup done.\")\n", - "print(config)\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MCGklf1vZf2M" - }, - "source": [ - "## **4.2. Train the network**\n", - "---\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n", - "\n", - "**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (DenoiSeg -- DenoiSeg Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n", - "\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder." 
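If training needs to be resumed in a later Colab session (for instance to stay within the 12-hour limit mentioned above), the saved model folder can be reloaded by name. A minimal sketch, assuming `model_name` and `model_path` hold the values chosen in section 3.1:

```python
from denoiseg.models import DenoiSeg

# Passing config=None makes DenoiSeg read the configuration and weights that were
# saved in model_path/model_name during the previous session (the same call is used
# in the Quality Control and prediction sections of this notebook).
model = DenoiSeg(config=None, name=model_name, basedir=model_path)
```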
- ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "xlcY9dvfm67C" - }, - "source": [ - "start = time.time()\n", - "\n", - "#@markdown ##Start Training\n", - "%memit\n", - "\n", - "\n", - "\n", - "history = model.train(X, Y, (X_val, Y_val))\n", - "\n", - "print(\"Training done.\")\n", - "%memit\n", - "\n", - "\n", - "print(\"Training, done.\")\n", - "\n", - "threshold, val_score = model.optimize_thresholds(val_images[:available_val_masks.shape[0]].astype(np.float32), val_masks, measure=measure_precision())\n", - "\n", - "print(\"The higest score of {} is achieved with threshold = {}.\".format(np.round(val_score, 3), threshold))\n", - "\n", - "\n", - "\n", - "# convert the history.history dict to a pandas DataFrame: \n", - "lossData = pd.DataFrame(history.history) \n", - "\n", - "if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n", - " shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "# The training evaluation.csv is saved (overwrites the Files if needed). \n", - "lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n", - "with open(lossDataCSVpath, 'w') as f:\n", - " writer = csv.writer(f)\n", - " writer.writerow(['loss','val_loss', 'learning rate','threshold'])\n", - " for i in range(len(history.history['loss'])):\n", - " writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i], str(threshold)])\n", - "\n", - "#Thresholdpath = model_path+'/'+model_name+'/Quality Control/optimal_threshold.csv'\n", - "#with open(Thresholdpath, 'w') as f1:\n", - " #writer1 = csv.writer(f1)\n", - " #writer1.writerow(['threshold'])\n", - " #writer1.writerow([str(threshold)])\n", - "\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "model.export_TF(name='DenoiSeg', \n", - " description='DenoiSeg 2D trained using ZeroCostDL4Mic.', \n", - " authors=[\"You\"],\n", - " test_img=X_val[0,...,0], axes='YX',\n", - " patch_shape=(64, 64))\n", - "\n", - "print(\"Your model has been sucessfully exported and can now also be used in the CSBDeep Fiji plugin\")\n", - "\n", - "#Create a pdf document with training summary\n", - "\n", - "pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_0Hynw3-xHp1" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "eAJzMwPA6tlH", - "cellView": "form" - }, - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "QC_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "QC_model_name = os.path.basename(QC_model_folder)\n", - "QC_model_path = os.path.dirname(QC_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", - "\n", - "full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n", - "if os.path.exists(full_QC_model_path):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", - "else:\n", - " \n", - " print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dhJROwlAMv5o" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vMzSP50kMv5p", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n", - "\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "\n", - "with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[0]))\n", - " vallossDataFromCSV.append(float(row[1]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (log scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n", - "plt.show()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X5_92nL2xdP6" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "**DenoiSeg** allow to both denoise and segment microscopy images. This section allow you to evaluate both tasks separetly.\n", - "\n", - "**Evaluation of the denoising**\n", - "\n", - "This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_Denoising_folder\" !\n", - "\n", - "**1. The SSIM (structural similarity) map** \n", - "\n", - "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n", - "\n", - "**mSSIM** is the SSIM value calculated across the entire window of both images.\n", - "\n", - "**The output below shows the SSIM maps with the mSSIM**\n", - "\n", - "**2. The RSE (Root Squared Error) map** \n", - "\n", - "This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n", - "\n", - "\n", - "**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n", - "\n", - "**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n", - "\n", - "**The output below shows the RSE maps with the NRMSE and PSNR values.**\n", - "\n", - "**Evaluation of the Segmentation**\n", - "\n", - "This option will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_Segmentation_folder ! The result for one of the image will also be displayed.\n", - "\n", - "The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "w90MdriMxhjD", - "cellView": "form" - }, - "source": [ - "#@markdown ##Choose what to evaluate\n", - "\n", - "Evaluate_Denoising = True #@param {type:\"boolean\"}\n", - "\n", - "Evaluate_Segmentation = True #@param {type:\"boolean\"}\n", - "\n", - "\n", - "# ------------- User input ------------\n", - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_Denoising_folder = \"\" #@param{type:\"string\"}\n", - "Target_Segmentation_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "\n", - "#@markdown ###If your model was trained outside of ZeroCostDl4Mic, please provide a threshold value for the segmentation (between 0-1):\n", - "\n", - "threshold = 0.5 #@param {type:\"number\"}\n", - "\n", - "# Create a quality control/Prediction Folder\n", - "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n", - " shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "#Activate the pretrained model. \n", - "config = None\n", - "model = DenoiSeg(config=None, name=QC_model_name, basedir=QC_model_path)\n", - "\n", - "#Load the threshold value. 
\n", - "\n", - "if os.path.exists(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - "\n", - " with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " \n", - " if \"threshold\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"Optimal segmentation threshold found\")\n", - " #find the last learning rate\n", - " threshold = csvRead[\"threshold\"].iloc[-1]\n", - "\n", - "# ------------- Prepare the model and run predictions ------------\n", - "# creates a loop, creating filenames and saving them\n", - "\n", - "thisdir = Path(Source_QC_folder)\n", - "\n", - "# r=root, d=directories, f = files\n", - "for r, d, f in os.walk(thisdir):\n", - " for file in f:\n", - " if \".tif\" in file:\n", - " print(os.path.join(r, file))\n", - "\n", - "for r, d, f in os.walk(thisdir):\n", - " for file in f:\n", - "\n", - "#Here we load the images\n", - " base_filename = os.path.basename(file)\n", - " test_images = imread(os.path.join(r, file))\n", - "\n", - "#Here we perform the predictions\n", - " predicted_channels = model.predict(test_images.astype(np.float32), axes='YX')\n", - " denoised_images= predicted_channels[...,0]\n", - " segmented_images= (compute_labels(predicted_channels, threshold))\n", - "\n", - "#Here we save the results\n", - " io.imsave(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"+\"/\"+\"Predicted_denoised_\"+base_filename, denoised_images)\n", - " io.imsave(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"+\"/\"+\"Predicted_segmentation_\"+base_filename, segmented_images)\n", - "\n", - "# ------------- Here we Start assessing the denoising against GT ------------\n", - "\n", - "if Evaluate_Denoising:\n", - " def ssim(img1, img2):\n", - " return structural_similarity(img1,img2,data_range=1.,full=True)\n", - "\n", - "\n", - " def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " \"\"\"Percentile-based image normalization.\"\"\"\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - " def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " if dtype is not None:\n", - " x = x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - " def norm_minmse(gt, x, normalize_gt=True):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - "\n", - " \"\"\"\n", - " normalizes and affinely scales an image pair such that the MSE is minimized \n", - " \n", - " Parameters\n", - " ----------\n", - " gt: ndarray\n", - " the ground truth image \n", - " x: ndarray\n", - " the image that will be affinely scaled \n", - " 
normalize_gt: bool\n", - " set to True of gt image should be normalized (default)\n", - " Returns\n", - " -------\n", - " gt_scaled, x_scaled \n", - " \"\"\"\n", - " if normalize_gt:\n", - " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", - " x = x.astype(np.float32, copy=False) - np.mean(x)\n", - " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", - " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", - " return gt, scale * x\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - " with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_Denoising_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n", - "\n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - "\n", - " for i in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", - " print('Running QC on: '+i)\n", - " # -------------------------------- Target test data (Ground truth) --------------------------------\n", - " test_GT = io.imread(os.path.join(Target_Denoising_folder, i))\n", - "\n", - " # -------------------------------- Source test data --------------------------------\n", - " test_source = io.imread(os.path.join(Source_QC_folder,i))\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n", - " test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n", - "\n", - " # -------------------------------- Prediction --------------------------------\n", - " test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",\"Predicted_denoised_\"+i))\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n", - " test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n", - "\n", - "\n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n", - "\n", - " #Save ssim_maps\n", - " img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n", - " img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n", - " \n", - " # Calculate the Root Squared Error (RSE) maps\n", - " img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", - " img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n", - "\n", - " # Save SE maps\n", - " img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n", - " img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n", - " 
io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n", - "\n", - "\n", - " # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n", - "\n", - " # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n", - " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", - " NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n", - " \n", - " # We can also measure the peak signal to noise ratio between the images\n", - " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", - " PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n", - "\n", - " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n", - "\n", - "\n", - "# All data is now processed saved\n", - " Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n", - " norm = simple_norm(x, percent = 99)\n", - "\n", - " plt.figure(figsize=(15,15))\n", - " # Currently only displays the last computed set, from memory\n", - " # Target (Ground-truth)\n", - " plt.subplot(3,3,1)\n", - " plt.axis('off')\n", - " img_GT = io.imread(os.path.join(Target_Denoising_folder, Test_FileList[-1]))\n", - " plt.imshow(img_GT, norm=norm, cmap='magma', interpolation='nearest')\n", - " plt.title('Target',fontsize=15)\n", - "\n", - "# Source\n", - " plt.subplot(3,3,2)\n", - " plt.axis('off')\n", - " img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n", - " plt.imshow(img_Source, norm=norm, cmap='magma', interpolation='nearest')\n", - " plt.title('Source',fontsize=15)\n", - "\n", - "#Prediction\n", - " plt.subplot(3,3,3)\n", - " plt.axis('off')\n", - " img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", \"Predicted_denoised_\"+Test_FileList[-1]))\n", - " plt.imshow(img_Prediction, norm=norm, cmap='magma', interpolation='nearest')\n", - " plt.title('Prediction',fontsize=15)\n", - "\n", - "#Setting up colours\n", - " cmap = plt.cm.CMRmap\n", - "\n", - " #SSIM between GT and Source\n", - " plt.subplot(3,3,5)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - " imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. 
Source',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - " plt.subplot(3,3,6)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - " plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#Root Squared Error between GT and Source\n", - " plt.subplot(3,3,8)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - " imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n", - " plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Source',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n", - "\n", - " plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#Root Squared Error between GT and Prediction\n", - " plt.subplot(3,3,9)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - " imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n", - " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_denoising.png',bbox_inches='tight',pad_inches=0)\n", - "#________________________________________________________________________\n", - "# Here we start testing the differences between GT and predicted masks\n", - "\n", - "if Evaluate_Segmentation:\n", - "\n", - " with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_Segmentation_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - " writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\"]) \n", - "\n", - "# define the images\n", - "\n", - " for n in os.listdir(Source_QC_folder):\n", - " \n", - " if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n", - " print('Running QC on: '+n)\n", - " test_input = io.imread(os.path.join(Source_QC_folder,n))\n", - " test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",\"Predicted_segmentation_\"+n))\n", - " test_ground_truth_image = io.imread(os.path.join(Target_Segmentation_folder, n))\n", - "\n", - " #Convert pixel values to 0 or 255\n", - " test_prediction_0_to_255 = test_prediction\n", - " test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n", - "\n", - " #Convert pixel values to 0 or 255\n", - " test_ground_truth_0_to_255 = test_ground_truth_image\n", - " test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n", - "\n", - " # Intersection over Union metric\n", - "\n", - " intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n", - " union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - " writer.writerow([n, str(iou_score)])\n", - "\n", - "\n", - " from astropy.visualization import simple_norm\n", - "\n", - " # ------------- For display ------------\n", - " print('--------------------------------------------------------------')\n", - " @interact\n", - " def show_QC_results(file = os.listdir(Source_QC_folder)):\n", - "\n", - " plt.figure(figsize=(25,5))\n", - " source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n", - " target_image = io.imread(os.path.join(Target_Segmentation_folder, file), as_gray = True)\n", - " prediction = io.imread(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/Predicted_segmentation_\"+file, as_gray = True)\n", - "\n", - " target_image_mask = np.empty_like(target_image)\n", - " target_image_mask[target_image > 0] = 255\n", - " target_image_mask[target_image == 0] = 0\n", - " \n", - " prediction_mask = np.empty_like(prediction)\n", - " prediction_mask[prediction > 0] = 255\n", - " prediction_mask[prediction == 0] = 0\n", - "\n", - " intersection = np.logical_and(target_image_mask, prediction_mask)\n", - " union = np.logical_or(target_image_mask, prediction_mask)\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - "\n", - " norm = simple_norm(source_image, percent = 99)\n", - "\n", - "\n", - " #Input\n", - " plt.subplot(1,4,1)\n", - " plt.axis('off')\n", - " plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n", - " plt.title('Input')\n", - "\n", - " #Ground-truth\n", - " plt.subplot(1,4,2)\n", - " plt.axis('off')\n", - " plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n", - " plt.title('Ground Truth')\n", - "\n", - " #Prediction\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n", - " plt.title('Prediction')\n", - "\n", - " #Overlay\n", - " plt.subplot(1,4,4)\n", - " plt.axis('off')\n", - " plt.imshow(target_image_mask, cmap='Greens')\n", - " plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n", - " plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)));\n", - " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_segmentation.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "#Export pdf summary of QC results\n", - "qc_pdf_export()\n" - ], - "execution_count": null, - 
"outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-tJeeJjLnRkP" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d8wuQGjoq6eN" - }, - "source": [ - "## **6.1. Generate prediction(s) from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n", - "\n", - "**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output images." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "y2TD5p7MZrEb", - "cellView": "form" - }, - "source": [ - "import imageio\n", - "\n", - "\n", - "#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n", - "\n", - "#@markdown ###Path to data to analyse and where predicted output should be saved:\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Result_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "# model name and path\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "#@markdown ###If your model was trained outside of ZeroCostDl4Mic, please provide a Threshold value for the segmentation (between 0-1):\n", - "\n", - "threshold = 0.5 #@param {type:\"number\"}\n", - "\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "\n", - "#Activate the pretrained model. \n", - "config = None\n", - "model = DenoiSeg(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n", - "\n", - "#Load the threshold value. 
\n", - "\n", - "if os.path.exists(os.path.join(full_Prediction_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - "\n", - " with open(os.path.join(full_Prediction_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " \n", - " if \"threshold\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"Optimal segmentation threshold found\")\n", - " #find the last learning rate\n", - " threshold = csvRead[\"threshold\"].iloc[-1]\n", - "\n", - "# creates a loop, creating filenames and saving them\n", - "\n", - "thisdir = Path(Data_folder)\n", - "outputdir = Path(Result_folder)\n", - "\n", - "# r=root, d=directories, f = files\n", - "for r, d, f in os.walk(thisdir):\n", - " for file in f:\n", - " if \".tif\" in file:\n", - " print(os.path.join(r, file))\n", - "\n", - "print(\"Processing...\")\n", - "for r, d, f in os.walk(thisdir):\n", - " for file in f:\n", - "\n", - "#Here we load the images\n", - " base_filename = os.path.basename(file)\n", - " test_images = imread(os.path.join(r, file))\n", - "\n", - "#Here we perform the predictions\n", - " predicted_channels = model.predict(test_images.astype(np.float32), axes='YX')\n", - " denoised_images= predicted_channels[...,0]\n", - " segmented_images= (compute_labels(predicted_channels, threshold))\n", - "\n", - "#Here we save the results\n", - " io.imsave(Result_folder+\"/\"+\"Predicted_denoised_\"+base_filename, denoised_images)\n", - " io.imsave(Result_folder+\"/\"+\"Predicted_segmentation_\"+base_filename,segmented_images)\n", - " \n", - "\n", - "\n", - "print(\"Images saved into folder:\", Result_folder)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qg31ghpfoNBD" - }, - "source": [ - "## **6.2. Assess predicted output**\n", - "---\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "CH7t08UooLba" - }, - "source": [ - "\n", - "# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n", - "\n", - "\n", - "# This will display a randomly chosen dataset input and predicted output\n", - "random_choice = random.choice(os.listdir(Data_folder))\n", - "x = imread(Data_folder+\"/\"+random_choice)\n", - "\n", - "os.chdir(Result_folder)\n", - "y = imread(Result_folder+\"/\"+\"Predicted_denoised_\"+random_choice)\n", - "z = imread(Result_folder+\"/\"+\"Predicted_segmentation_\"+random_choice)\n", - "\n", - "norm = simple_norm(x, percent = 99)\n", - "\n", - "plt.figure(figsize=(30,15))\n", - "plt.subplot(1, 4, 1)\n", - "plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n", - "plt.axis('off');\n", - "plt.title(\"Input\")\n", - "\n", - "plt.subplot(1, 4, 2)\n", - "plt.imshow(y, interpolation='nearest', norm=norm, cmap='magma')\n", - "plt.axis('off');\n", - "plt.title(\"Predicted denoised image\")\n", - "\n", - "plt.subplot(1, 4, 3)\n", - "plt.imshow(z, interpolation='nearest', vmin=0, vmax=1, cmap='viridis')\n", - "plt.axis('off');\n", - "plt.title(\"Predicted segmentation\")\n", - "\n", - "plt.show()\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.3. 
Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UvSlTaH14s3t" - }, - "source": [ - "#**Thank you for using DenoiSeg!**" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"DenoiSeg_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611075289867},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **Image denoising and segmentation using DenoiSeg 2D**\n","\n","---\n","\n"," DenoiSeg 2D is a deep-learning method that can be used to jointly denoise and segment 2D microscopy images. By running this notebook, you can train and use your own network. \n","\n"," The benefits of using DenoiSeg (compared to other Deep Learning-based segmentation methods) are more pronounced when only a few annotated images are available. However, the denoising part requires many images to perform well. Not all of the noisy images need to be labeled to train DenoiSeg.\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper: **DenoiSeg: Joint Denoising and Segmentation**\n","Tim-Oliver Buchholz, Mangal Prakash, Alexander Krull, Florian Jug\n","https://arxiv.org/abs/2005.02987\n","\n","And source code found in: /~https://github.com/juglab/DenoiSeg/wiki\n","\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","**it needs to have access to a paired training dataset made of images and their corresponding masks**. 
Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**Importantly, the benefits of using DenoiSeg are more pronounced when only limited numbers of segmentation annotations are available for training. However, DenoiSeg also expects that lots of noisy raw images are available to train the denoising part. It is therefore not required for all the noisy images to be annotated to train DenoiSeg**.\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Noisy Images (Training_source)\n"," - img_1.tif, img_2.tif, img_3.tif, img_4.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif\n"," - **Quality control dataset (optional, not required for training)**\n"," - Noisy Images\n"," - img_1.tif, img_2.tif\n"," - High SNR Images\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. Install DenoiSeg and Dependencies**\n","---"]},{"cell_type":"markdown","metadata":{"id":"nwYTuUfOrtPj"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"7TIzloUDrwOA","cellView":"form"},"source":["\n","#@markdown ##Install DenoiSeg and dependencies\n","\n","%tensorflow_version 1.x\n","import tensorflow \n","\n","!pip uninstall -y keras-nightly\n","\n","!pip3 install h5py==2.10.0\n","\n","!pip install denoiseg\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","\n","!pip install q keras==2.2.5\n","\n","#Force session restart\n","exit(0)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"XY4ZDsAMr3Ef"},"source":["## **1.2. 
Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
\n","\n"]},{"cell_type":"markdown","metadata":{"id":"i2d3Z_Ggr7Ng"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["#@markdown ##Load key dependencies\n","\n","\n","Notebook_version = '1.13'\n","Network = 'DenoiSeg'\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory\n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","# Here we enable Tensorflow 1. 
\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","%load_ext memory_profiler\n","\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to Denoiseg -------\n","\n","import numpy as np\n","from matplotlib import pyplot as plt\n","from scipy import ndimage\n","\n","from denoiseg.models import DenoiSeg, DenoiSegConfig\n","from denoiseg.utils.misc_utils import combine_train_test_data, shuffle_train_data, augment_data\n","from denoiseg.utils.seg_utils import *\n","from denoiseg.utils.compute_precision_threshold import measure_precision, compute_labels\n","\n","from csbdeep.utils import plot_history\n","from tifffile import imread, imsave\n","from glob import glob\n","\n","import urllib\n","import os\n","import zipfile\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate and Time: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', with a batch size of '+str(batch_size)+' and using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The GPU used was a '+gpu_name+'.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'Data augmentation was enabled'\n","\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
batch_size{1}
number_of_steps{2}
percentage_validation{3}
Priority{4}
initial_learning_rate{5}
\n"," \"\"\".format(number_of_epochs,batch_size,number_of_steps,percentage_validation,Priority,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_DenoiSeg.png').shape\n"," pdf.image('/content/TrainingDataExample_DenoiSeg.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- DenoiSeg: Buchholz, Prakash, et al. \"DenoiSeg: Joint Denoising and Segmentation\", arXiv 2020.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n","\n","\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'DenoiSeg_2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_segmentation.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n","\n"," if 
Evaluate_Segmentation:\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_segmentation.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_segmentation.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+\"/Quality Control/QC_metrics_Segmentation_\"+QC_model_name+\".csv\", 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," IoU = header[1] \n"," header = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,IoU)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," image = row[0]\n"," IoU = row[1] \n"," cells = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(IoU),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}
{0}{1}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," if Evaluate_Denoising:\n","\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_denoising.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_denoising.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+\"/Quality Control/QC_metrics_Denoising_\"+QC_model_name+\".csv\", 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n"," print('------------------------------')\n"," print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. 
\n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"NfWUBJ-brd8V"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
"]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 30**\n"," \n","**Advanced Parameters - experienced users only**\n","\n","**`Priority`:** Choose how much relative the importance to assign to the denoising \n","and segmentation tasks by choosing an appropriate value (between 0 and 1; with 0 being only segmentation and 1 being only denoising. **Default value: 0.5**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: depends on number of patches, min 100; max 400**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# create DataGenerator-object.\n","\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","Priority = 0.5#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","batch_size = 128#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," Priority = 0.5\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_target))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","# Here we count the number of files in the training target folder\n","Mask_Filelist = os.listdir(Training_target)\n","Mask_number_files = len(Mask_Filelist)\n","\n","# Here we count the number of file to use for validation\n","Mask_for_validation = int((Mask_number_files)/percentage_validation)\n","\n","if Mask_for_validation == 0:\n"," Mask_for_validation = 2\n","if Mask_for_validation == 1:\n"," Mask_for_validation = 2\n","\n","# Here we count the number of files in the training target folder\n","Noisy_Filelist = os.listdir(Training_source)\n","Noisy_number_files = len(Noisy_Filelist)\n","\n","# Here we count the number of file to use for validation\n","Noisy_for_validation = int((Noisy_number_files)/percentage_validation)\n","\n","if Noisy_for_validation == 0:\n"," Noisy_for_validation = 1\n","\n","#Here we find the noisy images that do not have masks\n","noisy_image_no_mask_list = list(set(Noisy_Filelist) - set(Mask_Filelist))\n","\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","Training_target_temp = \"/content/training_target\"\n","if os.path.exists(Training_target_temp):\n"," 
shutil.rmtree(Training_target_temp)\n","os.makedirs(Training_target_temp)\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","Validation_target_temp = \"/content/validation_target\"\n","if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n","os.makedirs(Validation_target_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","list_target = os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n","\n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n","#Here we move images to be used for validation\n","for i in range(Mask_for_validation): \n"," shutil.move(Training_source_temp+\"/\"+list_target[i], Validation_source_temp+\"/\"+list_target[i])\n"," shutil.move(Training_target_temp+\"/\"+list_target[i], Validation_target_temp+\"/\"+list_target[i])\n","\n","#Here we move a few more noisy images for validation\n","if noisy_image_no_mask_list:\n"," for y in range(Noisy_for_validation): \n"," shutil.move(Training_source_temp+\"/\"+noisy_image_no_mask_list[y], Validation_source_temp+\"/\"+noisy_image_no_mask_list[y])\n","\n","\n","print(\"Parameters initiated.\")\n","\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', vmin=0, vmax=1, cmap='viridis')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_DenoiSeg.png',bbox_inches='tight',pad_inches=0)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis (multiply the dataset by 8). \n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. 
Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a DenoiSeg model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in 
csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YOp8HwavpoON"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"CKSKY4icpcKb"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"cellView":"form","id":"0LM_L-5Spb2z"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","print(\"In progress...\")\n","\n","Training_source_dir = Training_source_temp\n","Training_target_dir = Training_target_temp\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","validation_images_tiff=Validation_source_temp+\"/*.tif\"\n","validation_mask_tiff=Validation_target_temp+\"/*.tif\"\n","\n","train_images = imread(sorted(glob(training_images_tiff)))\n","val_images = imread(sorted(glob(validation_images_tiff)))\n","\n","available_train_masks = imread(sorted(glob(mask_images_tiff)))\n","available_val_masks = imread(sorted(glob(validation_mask_tiff)))\n","\n","#This allows the users to not have all their training images segmented\n","blank_images_train = np.zeros((train_images.shape[0]-available_train_masks.shape[0], available_train_masks.shape[1], available_train_masks.shape[2]))\n","blank_images_val = np.zeros((val_images.shape[0]-available_val_masks.shape[0], available_val_masks.shape[1], available_val_masks.shape[2]))\n","blank_images_train = blank_images_train.astype(\"uint16\")\n","blank_images_val = blank_images_val.astype(\"uint16\")\n","\n","train_masks = np.concatenate((available_train_masks,blank_images_train), axis = 0)\n","val_masks = np.concatenate((available_val_masks,blank_images_val), axis = 0)\n","\n","\n","if not Use_Data_augmentation:\n"," X, Y_train_masks = train_images, train_masks\n","\n","# Now we apply data augmentation to the training patches:\n","# Rotate four times by 90 degree and add flipped versions.\n","if Use_Data_augmentation:\n"," X, Y_train_masks = augment_data(train_images, train_masks)\n","\n","X_val, Y_val_masks = val_images, val_masks\n","\n","# Here we add the channel dimension to our input images.\n","# Dimensionality for training has to be 'SYXC' (Sample, Y-Dimension, X-Dimension, Channel)\n","X = X[...,np.newaxis]\n","Y = convert_to_oneHot(Y_train_masks)\n","X_val = X_val[...,np.newaxis]\n","Y_val = convert_to_oneHot(Y_val_masks)\n","print(\"Shape of X: {}\".format(X.shape))\n","print(\"Shape of Y: {}\".format(Y.shape))\n","print(\"Shape of X_val: {}\".format(X_val.shape))\n","print(\"Shape of Y_val: {}\".format(Y_val.shape))\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= max(100, min(int(X.shape[0]/batch_size), 400))\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","\n","config = DenoiSegConfig(X, unet_kern_size=3, n_channel_out=4, relative_weights = [1.0,1.0,5.0],\n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," batch_norm=True, train_batch_size=batch_size, unet_n_first = 32, \n"," unet_n_depth=4, denoiseg_alpha=Priority, 
train_learning_rate = initial_learning_rate, train_tensorboard=False)\n","\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," \n","# create network model.\n","\n","model = DenoiSeg(config=config, name=model_name, basedir=model_path)\n","\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","#Export summary of training parameters as pdf\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","print(\"Setup done.\")\n","print(config)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["## **4.2. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (DenoiSeg -- DenoiSeg Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n","\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"cellView":"form","id":"xlcY9dvfm67C"},"source":["start = time.time()\n","\n","#@markdown ##Start Training\n","%memit\n","\n","\n","\n","history = model.train(X, Y, (X_val, Y_val))\n","\n","print(\"Training done.\")\n","%memit\n","\n","\n","print(\"Training, done.\")\n","\n","threshold, val_score = model.optimize_thresholds(val_images[:available_val_masks.shape[0]].astype(np.float32), val_masks, measure=measure_precision())\n","\n","print(\"The higest score of {} is achieved with threshold = {}.\".format(np.round(val_score, 3), threshold))\n","\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
\n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate','threshold'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i], str(threshold)])\n","\n","#Thresholdpath = model_path+'/'+model_name+'/Quality Control/optimal_threshold.csv'\n","#with open(Thresholdpath, 'w') as f1:\n"," #writer1 = csv.writer(f1)\n"," #writer1.writerow(['threshold'])\n"," #writer1.writerow([str(threshold)])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='DenoiSeg', \n"," description='DenoiSeg 2D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='YX',\n"," patch_shape=(64, 64))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBDeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","**DenoiSeg** allow to both denoise and segment microscopy images. This section allow you to evaluate both tasks separetly.\n","\n","**Evaluation of the denoising**\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_Denoising_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. 
It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","**Evaluation of the Segmentation**\n","\n","This option will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_Segmentation_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose what to evaluate\n","\n","Evaluate_Denoising = True #@param {type:\"boolean\"}\n","\n","Evaluate_Segmentation = True #@param {type:\"boolean\"}\n","\n","\n","# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_Denoising_folder = \"\" #@param{type:\"string\"}\n","Target_Segmentation_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#@markdown ###If your model was trained outside of ZeroCostDl4Mic, please provide a threshold value for the segmentation (between 0-1):\n","\n","threshold = 0.5 #@param {type:\"number\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","#Activate the pretrained model. \n","config = None\n","model = DenoiSeg(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","#Load the threshold value. 
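# --- Optional quick check (illustrative sketch only) -----------------------------
# Condensed sketch of the denoising metrics described above, computed for a single
# image pair with scikit-image, without the per-pixel maps and the MSE-minimising
# normalisation that the QC code further down applies. It assumes matching file
# names in Source_QC_folder and Target_Denoising_folder; threshold loading for the
# segmentation continues just below this block.
from skimage.metrics import structural_similarity as _ssim_metric
from skimage.metrics import peak_signal_noise_ratio as _psnr_metric
from skimage.metrics import normalized_root_mse as _nrmse_metric

_qc_files = [f for f in os.listdir(Source_QC_folder) if f.endswith('.tif')]
if Evaluate_Denoising and _qc_files and os.path.exists(os.path.join(Target_Denoising_folder, _qc_files[0])):
    _gt = io.imread(os.path.join(Target_Denoising_folder, _qc_files[0])).astype(np.float32)
    _src = io.imread(os.path.join(Source_QC_folder, _qc_files[0])).astype(np.float32)
    _rng = float(_gt.max() - _gt.min())
    print('Input vs GT mSSIM:', round(_ssim_metric(_gt, _src, data_range=_rng), 3))
    print('Input vs GT NRMSE:', round(_nrmse_metric(_gt, _src), 3))
    print('Input vs GT PSNR :', round(_psnr_metric(_gt, _src, data_range=_rng), 3))

# Intersection over Union between two binary masks A and B (used for the
# segmentation QC below): IoU = |A AND B| / |A OR B|, e.g.
#   iou = np.logical_and(A, B).sum() / np.logical_or(A, B).sum()
# ----------------------------------------------------------------------------------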
\n","\n","if os.path.exists(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"threshold\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"Optimal segmentation threshold found\")\n"," #find the last learning rate\n"," threshold = csvRead[\"threshold\"].iloc[-1]\n","\n","# ------------- Prepare the model and run predictions ------------\n","# creates a loop, creating filenames and saving them\n","\n","thisdir = Path(Source_QC_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n","\n","#Here we load the images\n"," base_filename = os.path.basename(file)\n"," test_images = imread(os.path.join(r, file))\n","\n","#Here we perform the predictions\n"," predicted_channels = model.predict(test_images.astype(np.float32), axes='YX')\n"," denoised_images= predicted_channels[...,0]\n"," segmented_images= (compute_labels(predicted_channels, threshold))\n","\n","#Here we save the results\n"," io.imsave(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"+\"/\"+\"Predicted_denoised_\"+base_filename, denoised_images)\n"," io.imsave(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"+\"/\"+\"Predicted_segmentation_\"+base_filename, segmented_images)\n","\n","# ------------- Here we Start assessing the denoising against GT ------------\n","\n","if Evaluate_Denoising:\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = 
x.astype(np.float32, copy=False) - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_Denoising_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_Denoising_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",\"Predicted_denoised_\"+i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," 
# We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n"," Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n"," norm = simple_norm(x, percent = 99)\n","\n"," plt.figure(figsize=(15,15))\n"," # Currently only displays the last computed set, from memory\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(Target_Denoising_folder, Test_FileList[-1]))\n"," plt.imshow(img_GT, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n"," plt.imshow(img_Source, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", \"Predicted_denoised_\"+Test_FileList[-1]))\n"," plt.imshow(img_Prediction, norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_denoising.png',bbox_inches='tight',pad_inches=0)\n","#________________________________________________________________________\n","# Here we start testing the differences between GT and predicted masks\n","\n","if Evaluate_Segmentation:\n","\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_Segmentation_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",\"Predicted_segmentation_\"+n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_Segmentation_folder, n))\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n","\n"," from astropy.visualization import simple_norm\n","\n"," # ------------- For display ------------\n"," print('--------------------------------------------------------------')\n"," @interact\n"," def show_QC_results(file = os.listdir(Source_QC_folder)):\n","\n"," plt.figure(figsize=(25,5))\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n"," target_image = io.imread(os.path.join(Target_Segmentation_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/Predicted_segmentation_\"+file, as_gray = True)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n","\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3)));\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_segmentation.png',bbox_inches='tight',pad_inches=0)\n","\n","#Export pdf summary of QC results\n","qc_pdf_export()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). 
First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["import imageio\n","\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###If your model was trained outside of ZeroCostDl4Mic, please provide a Threshold value for the segmentation (between 0-1):\n","\n","threshold = 0.5 #@param {type:\"number\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. \n","config = None\n","model = DenoiSeg(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","#Load the threshold value. 
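# --- Optional quick check (illustrative sketch only) -----------------------------
# Minimal example of processing a single image with the model object created above.
# "example.tif" is a placeholder file name; at this point `threshold` is still the
# value entered above (the optimal value stored with the model is loaded just below).
_example_path = os.path.join(Data_folder, 'example.tif')
if os.path.exists(_example_path):
    _channels = model.predict(imread(_example_path).astype(np.float32), axes='YX')
    _denoised = _channels[..., 0]                   # channel 0 holds the denoised image
    _labels = compute_labels(_channels, threshold)  # thresholded segmentation labels
    io.imsave(os.path.join(Result_folder, 'Example_denoised.tif'), _denoised)
    io.imsave(os.path.join(Result_folder, 'Example_segmentation.tif'), _labels)
# ----------------------------------------------------------------------------------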
\n","\n","if os.path.exists(os.path.join(full_Prediction_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(full_Prediction_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"threshold\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"Optimal segmentation threshold found\")\n"," #find the last learning rate\n"," threshold = csvRead[\"threshold\"].iloc[-1]\n","\n","# creates a loop, creating filenames and saving them\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","print(\"Processing...\")\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n","\n","#Here we load the images\n"," base_filename = os.path.basename(file)\n"," test_images = imread(os.path.join(r, file))\n","\n","#Here we perform the predictions\n"," predicted_channels = model.predict(test_images.astype(np.float32), axes='YX')\n"," denoised_images= predicted_channels[...,0]\n"," segmented_images= (compute_labels(predicted_channels, threshold))\n","\n","#Here we save the results\n"," io.imsave(Result_folder+\"/\"+\"Predicted_denoised_\"+base_filename, denoised_images)\n"," io.imsave(Result_folder+\"/\"+\"Predicted_segmentation_\"+base_filename,segmented_images)\n"," \n","\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Qg31ghpfoNBD"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"CH7t08UooLba"},"source":["\n","# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+\"Predicted_denoised_\"+random_choice)\n","z = imread(Result_folder+\"/\"+\"Predicted_segmentation_\"+random_choice)\n","\n","norm = simple_norm(x, percent = 99)\n","\n","plt.figure(figsize=(30,15))\n","plt.subplot(1, 4, 1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off');\n","plt.title(\"Input\")\n","\n","plt.subplot(1, 4, 2)\n","plt.imshow(y, interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off');\n","plt.title(\"Predicted denoised image\")\n","\n","plt.subplot(1, 4, 3)\n","plt.imshow(z, interpolation='nearest', vmin=0, vmax=1, cmap='viridis')\n","plt.axis('off');\n","plt.title(\"Predicted segmentation\")\n","\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"W5BxExzzs7gh"},"source":["# **7. 
Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version now includes an automatic restart allowing to set the h5py library to v2.10.\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","\n","* This version also now includes built-in version check and the version log that you're reading now."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using DenoiSeg!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/Detectron2_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/Detectron2_2D_ZeroCostDL4Mic.ipynb index d5b37563..10200cbc 100644 --- a/Colab_notebooks/Beta notebooks/Detectron2_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Beta notebooks/Detectron2_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Detectron2_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[]},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"id":"JhcOyBQjR54F"},"source":["#**This notebook is in beta**\n","Expect some instabilities and bugs.\n","\n","**Currently missing features include:**\n","\n","- Augmentation cannot be disabled\n","- Exported results include only a simple CSV file. More options will be included in the next releases\n","- Training and QC reports are not generated\n"]},{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **Detectron2 (2D)**\n","\n"," Detectron2 is a deep-learning method designed to perform object detection and classification of objects in images. Detectron2 is Facebook AI Research's next generation software system that implements state-of-the-art object detection algorithms. It is a ground-up rewrite of the previous version, Detectron, and it originates from maskrcnn-benchmark. More information on Detectron2 can be found on the Detectron2 github pages (/~https://github.com/facebookresearch/detectron2).\n","\n","\n","\n","**This particular notebook enables object detection and classification on 2D images given ground truth bounding boxes. If you are interested in image segmentation, you should use our U-net or Stardist notebooks instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"NDICs5NxYEWP"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","id":"R575GX8cX2aP"},"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","\n","#Apache License\n","#Version 2.0, January 2004\n","#http://www.apache.org/licenses/\n","\n","#TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n","\n","#1. Definitions.\n","\n","#\"License\" shall mean the terms and conditions for use, reproduction,\n","#and distribution as defined by Sections 1 through 9 of this document.\n","\n","#\"Licensor\" shall mean the copyright owner or entity authorized by\n","#the copyright owner that is granting the License.\n","\n","#\"Legal Entity\" shall mean the union of the acting entity and all\n","#other entities that control, are controlled by, or are under common\n","#control with that entity. For the purposes of this definition,\n","#\"control\" means (i) the power, direct or indirect, to cause the\n","#direction or management of such entity, whether by contract or\n","#otherwise, or (ii) ownership of fifty percent (50%) or more of the\n","#outstanding shares, or (iii) beneficial ownership of such entity.\n","\n","#\"You\" (or \"Your\") shall mean an individual or Legal Entity\n","#exercising permissions granted by this License.\n","\n","#\"Source\" form shall mean the preferred form for making modifications,\n","#including but not limited to software source code, documentation\n","#source, and configuration files.\n","\n","#\"Object\" form shall mean any form resulting from mechanical\n","#transformation or translation of a Source form, including but\n","#not limited to compiled object code, generated documentation,\n","#and conversions to other media types.\n","\n","#\"Work\" shall mean the work of authorship, whether in Source or\n","#Object form, made available under the License, as indicated by a\n","#copyright notice that is included in or attached to the work\n","#(an example is provided in the Appendix below).\n","\n","#\"Derivative Works\" shall mean any work, whether in Source or Object\n","#form, that is based on (or derived from) the Work and for which the\n","#editorial revisions, annotations, elaborations, or other modifications\n","#represent, as a whole, an original work of authorship. For the purposes\n","#of this License, Derivative Works shall not include works that remain\n","#separable from, or merely link (or bind by name) to the interfaces of,\n","#the Work and Derivative Works thereof.\n","\n","#\"Contribution\" shall mean any work of authorship, including\n","#the original version of the Work and any modifications or additions\n","#to that Work or Derivative Works thereof, that is intentionally\n","#submitted to Licensor for inclusion in the Work by the copyright owner\n","#or by an individual or Legal Entity authorized to submit on behalf of\n","#the copyright owner. 
For the purposes of this definition, \"submitted\"\n","#means any form of electronic, verbal, or written communication sent\n","#to the Licensor or its representatives, including but not limited to\n","#communication on electronic mailing lists, source code control systems,\n","#and issue tracking systems that are managed by, or on behalf of, the\n","#Licensor for the purpose of discussing and improving the Work, but\n","#excluding communication that is conspicuously marked or otherwise\n","#designated in writing by the copyright owner as \"Not a Contribution.\"\n","\n","#\"Contributor\" shall mean Licensor and any individual or Legal Entity\n","#on behalf of whom a Contribution has been received by Licensor and\n","#subsequently incorporated within the Work.\n","\n","#2. Grant of Copyright License. Subject to the terms and conditions of\n","#this License, each Contributor hereby grants to You a perpetual,\n","#worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n","#copyright license to reproduce, prepare Derivative Works of,\n","#publicly display, publicly perform, sublicense, and distribute the\n","#Work and such Derivative Works in Source or Object form.\n","\n","#3. Grant of Patent License. Subject to the terms and conditions of\n","#this License, each Contributor hereby grants to You a perpetual,\n","#worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n","#(except as stated in this section) patent license to make, have made,\n","#use, offer to sell, sell, import, and otherwise transfer the Work,\n","#where such license applies only to those patent claims licensable\n","#by such Contributor that are necessarily infringed by their\n","#Contribution(s) alone or by combination of their Contribution(s)\n","#with the Work to which such Contribution(s) was submitted. If You\n","#institute patent litigation against any entity (including a\n","#cross-claim or counterclaim in a lawsuit) alleging that the Work\n","#or a Contribution incorporated within the Work constitutes direct\n","#or contributory patent infringement, then any patent licenses\n","#granted to You under this License for that Work shall terminate\n","#as of the date such litigation is filed.\n","\n","#4. Redistribution. 
You may reproduce and distribute copies of the\n","#Work or Derivative Works thereof in any medium, with or without\n","#modifications, and in Source or Object form, provided that You\n","#meet the following conditions:\n","\n","#(a) You must give any other recipients of the Work or\n","#Derivative Works a copy of this License; and\n","\n","#(b) You must cause any modified files to carry prominent notices\n","#stating that You changed the files; and\n","\n","#(c) You must retain, in the Source form of any Derivative Works\n","#that You distribute, all copyright, patent, trademark, and\n","#attribution notices from the Source form of the Work,\n","#excluding those notices that do not pertain to any part of\n","#the Derivative Works; and\n","\n","#(d) If the Work includes a \"NOTICE\" text file as part of its\n","#distribution, then any Derivative Works that You distribute must\n","#include a readable copy of the attribution notices contained\n","#within such NOTICE file, excluding those notices that do not\n","#pertain to any part of the Derivative Works, in at least one\n","#of the following places: within a NOTICE text file distributed\n","#as part of the Derivative Works; within the Source form or\n","#documentation, if provided along with the Derivative Works; or,\n","#within a display generated by the Derivative Works, if and\n","#wherever such third-party notices normally appear. The contents\n","#of the NOTICE file are for informational purposes only and\n","#do not modify the License. You may add Your own attribution\n","#notices within Derivative Works that You distribute, alongside\n","#or as an addendum to the NOTICE text from the Work, provided\n","#that such additional attribution notices cannot be construed\n","#as modifying the License.\n","\n","#You may add Your own copyright statement to Your modifications and\n","#may provide additional or different license terms and conditions\n","#for use, reproduction, or distribution of Your modifications, or\n","#for any such Derivative Works as a whole, provided Your use,\n","#reproduction, and distribution of the Work otherwise complies with\n","#the conditions stated in this License.\n","\n","#5. Submission of Contributions. Unless You explicitly state otherwise,\n","#any Contribution intentionally submitted for inclusion in the Work\n","#by You to the Licensor shall be under the terms and conditions of\n","#this License, without any additional terms or conditions.\n","#Notwithstanding the above, nothing herein shall supersede or modify\n","#the terms of any separate license agreement you may have executed\n","#with Licensor regarding such Contributions.\n","\n","#6. Trademarks. This License does not grant permission to use the trade\n","#names, trademarks, service marks, or product names of the Licensor,\n","#except as required for reasonable and customary use in describing the\n","#origin of the Work and reproducing the content of the NOTICE file.\n","\n","#7. Disclaimer of Warranty. Unless required by applicable law or\n","#agreed to in writing, Licensor provides the Work (and each\n","#Contributor provides its Contributions) on an \"AS IS\" BASIS,\n","#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n","#implied, including, without limitation, any warranties or conditions\n","#of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n","#PARTICULAR PURPOSE. 
You are solely responsible for determining the\n","#appropriateness of using or redistributing the Work and assume any\n","#risks associated with Your exercise of permissions under this License.\n","\n","#8. Limitation of Liability. In no event and under no legal theory,\n","#whether in tort (including negligence), contract, or otherwise,\n","#unless required by applicable law (such as deliberate and grossly\n","#negligent acts) or agreed to in writing, shall any Contributor be\n","#liable to You for damages, including any direct, indirect, special,\n","#incidental, or consequential damages of any character arising as a\n","#result of this License or out of the use or inability to use the\n","#Work (including but not limited to damages for loss of goodwill,\n","#work stoppage, computer failure or malfunction, or any and all\n","#other commercial damages or losses), even if such Contributor\n","#has been advised of the possibility of such damages.\n","\n","#9. Accepting Warranty or Additional Liability. While redistributing\n","#the Work or Derivative Works thereof, You may choose to offer,\n","#and charge a fee for, acceptance of support, warranty, indemnity,\n","#or other liability obligations and/or rights consistent with this\n","#License. However, in accepting such obligations, You may act only\n","#on Your own behalf and on Your sole responsibility, not on behalf\n","#of any other Contributor, and only if You agree to indemnify,\n","#defend, and hold each Contributor harmless for any liability\n","#incurred by, or claims asserted against, such Contributor by reason\n","#of your accepting any such warranty or additional liability.\n","\n","#END OF TERMS AND CONDITIONS\n","\n","#APPENDIX: How to apply the Apache License to your work.\n","\n","#To apply the Apache License to your work, attach the following\n","#boilerplate notice, with the fields enclosed by brackets \"[]\"\n","#replaced with your own identifying information. (Don't include\n","#the brackets!) The text should be enclosed in the appropriate\n","#comment syntax for the file format. 
We also recommend that a\n","#file or class name and description of purpose be included on the\n","#same \"printed page\" as the copyright notice for easier\n","#identification within third-party archives.\n","\n","#Copyright [yyyy] [name of copyright owner]\n","\n","\n","#Licensed under the Apache License, Version 2.0 (the \"License\");\n","#you may not use this file except in compliance with the License.\n","#You may obtain a copy of the License at\n","\n","#http://www.apache.org/licenses/LICENSE-2.0\n","\n","#Unless required by applicable law or agreed to in writing, software\n","#distributed under the License is distributed on an \"AS IS\" BASIS,\n","#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n","#See the License for the specific language governing permissions and\n","#limitations under the License."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. 
Before getting started**\n","---\n"," Preparing the dataset carefully is essential to make this Detectron2 notebook work. This model requires as input a set of images and as target a list of annotation files in Pascal VOC format. The annotation files should have the exact same name as the input files, except with an .xml instead of the .png extension. The annotation files contain the class labels and all bounding boxes for the objects for each image in your dataset. Most annotation tools offer the option of exporting annotations in this format, and hand-annotation software will typically save the annotations in this format automatically. \n","\n"," If you want to assemble your own dataset, we recommend using the open source https://www.makesense.ai/ resource. You can follow our instructions on how to label your dataset with this tool on our [wiki](/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki/Object-Detection-(YOLOv2)).\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .png files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Input images (Training_source)\n"," - img_1.png, img_2.png, ...\n"," - Annotation files (Training_source_annotations)\n"," - img_1.xml, img_2.xml, ...\n"," - **Quality control dataset**\n"," - Input images\n"," - img_1.png, img_2.png\n"," - Annotation files\n"," - img_1.xml, img_2.xml\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"cbTknRcviyT7"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"h5i5CS2bSmZr"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","#%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n3B3meGTbYVi"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n"," \n","#@markdown * Click on the URL. \n"," \n","#@markdown * Sign in your Google Account. \n"," \n","#@markdown * Copy the authorization code. \n"," \n","#@markdown * Enter the authorization code. \n"," \n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n"," \n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install Detectron2 and dependencies**\n","---"]},{"cell_type":"markdown","metadata":{"id":"yg1vZe88JEyk"},"source":["## **2.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"Tw1Usk1iPvRQ","cellView":"form"},"source":[" \n","#@markdown ##Install dependencies and Detectron2\n"," \n","# install dependencies\n","#!pip install -U torch torchvision cython\n","!pip install -U 'git+/~https://github.com/facebookresearch/fvcore.git' 'git+/~https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'\n","import torch, torchvision\n","torch.__version__\n"," \n","!git clone /~https://github.com/facebookresearch/detectron2 detectron2_repo\n","!pip install -e detectron2_repo\n","\n","!pip install wget\n","\n","#Force session restart\n","exit(0)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"6-VxlZUkKLgC"},"source":["## **2.2. Restart your runtime**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"xhWNIu6cf5G8"},"source":["** Your Runtime has automatically restarted. This is normal.**\n","\n"]},{"cell_type":"markdown","metadata":{"id":"5nXTBntzKRWu"},"source":["## **2.3. 
Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n"," \n","#@markdown ##Play this cell to load the required dependencies\n","import wget\n","# Some basic setup: \n","import detectron2\n","from detectron2.utils.logger import setup_logger\n","setup_logger()\n"," \n","# import some common libraries\n","import numpy as np\n","import os, json, cv2, random\n","from google.colab.patches import cv2_imshow\n"," \n","import yaml\n"," \n","#Download the script to convert XML into COCO\n"," \n","wget.download(\"/~https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Tools/voc2coco.py\", \"/content\")\n"," \n"," \n","# import some common detectron2 utilities\n","from detectron2 import model_zoo\n","from detectron2.engine import DefaultPredictor\n","from detectron2.config import get_cfg\n","from detectron2.utils.visualizer import Visualizer\n","from detectron2.data import MetadataCatalog, DatasetCatalog\n","from detectron2.utils.visualizer import ColorMode\n","\n","from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader\n","from datetime import datetime\n","from detectron2.data.catalog import Metadata\n","\n","from detectron2.config import get_cfg\n","from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader\n","from detectron2.evaluation import COCOEvaluator, inference_on_dataset\n","from detectron2.engine import DefaultTrainer\n","from detectron2.data.datasets import register_coco_instances\n","from detectron2.utils.visualizer import ColorMode\n","import glob\n","from detectron2.checkpoint import Checkpointer\n","from detectron2.config import get_cfg\n","import os\n"," \n"," \n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n"," \n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n"," \n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n"," \n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n"," \n"," \n","from detectron2.engine import DefaultTrainer\n","from detectron2.evaluation import COCOEvaluator\n"," \n","class CocoTrainer(DefaultTrainer):\n"," \n"," @classmethod\n"," def build_evaluator(cls, cfg, dataset_name, output_folder=None):\n"," \n"," if output_folder is None:\n"," os.makedirs(\"coco_eval\", exist_ok=True)\n"," output_folder = \"coco_eval\"\n"," \n"," return COCOEvaluator(dataset_name, cfg, False, output_folder)\n"," \n"," \n"," \n","print(\"Librairies loaded\")\n"," \n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n"," 
\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n"," \n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n"," \n"," \n"," \n","#Failsafes\n","cell_ran_prediction = 0\n","cell_ran_training = 0\n","cell_ran_QC_training_dataset = 0\n","cell_ran_QC_QC_dataset = 0"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"iwjra6kMKmUA"},"source":["# **3. Select your parameters and paths**\n"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and the annotation data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`labels`:** Input the name of the differentes labels used to annotate your dataset (separated by a comma).\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_iteration`:** Input how many iterations to use to train the network. Initial results can be observed using 1000 iterations but consider using 5000 or more iterations to train your models. **Default value: 2000**\n"," \n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0001**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# create DataGenerator-object.\n","\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Labels\n","#@markdown Input the name of the differentes labels present in your training dataset separated by a comma\n","labels = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = model_path+'/'+model_name+'/'\n","\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of iterations:\n","number_of_iteration = 2000#@param {type:\"number\"}\n","\n","\n","#Here we store the informations related to our labels\n","\n","list_of_labels = labels.split(\", \")\n","with open('/content/labels.txt', 'w') as f:\n"," for item in list_of_labels:\n"," print(item, file=f)\n","\n","number_of_labels = len(list_of_labels)\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," percentage_validation = 10\n"," initial_learning_rate = 0.001\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = True\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","# Here we split the data between training and validation\n","# Here we count the number of files in the training target folder\n","Filelist = os.listdir(Training_target)\n","number_files = len(Filelist)\n","\n","File_for_validation = int((number_files)/percentage_validation)+1\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","Training_target_temp = \"/content/training_target\"\n","if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n","os.makedirs(Training_target_temp)\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","Validation_target_temp = \"/content/validation_target\"\n","if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n","os.makedirs(Validation_target_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","list_target = os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n","\n"," \n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n","\n","list_source_temp = 
os.listdir(os.path.join(Training_source_temp))\n","list_target_temp = os.listdir(os.path.join(Training_target_temp))\n","\n","\n","#Here we move images to be used for validation\n","for i in range(File_for_validation):\n","\n"," name = list_source_temp[i]\n"," shutil.move(Training_source_temp+\"/\"+name, Validation_source_temp+\"/\"+name)\n","\n"," shortname_no_extension = name[:-4]\n","\n"," shutil.move(Training_target_temp+\"/\"+shortname_no_extension+\".xml\", Validation_target_temp+\"/\"+shortname_no_extension+\".xml\")\n","\n","# Here we convert the XML files into COCO format to be loaded in detectron2\n","\n","#First we need to create list of labels to generate the json dictionaries\n","\n","list_source_training_temp = os.listdir(os.path.join(Training_source_temp))\n","list_source_validation_temp = os.listdir(os.path.join(Validation_source_temp))\n","\n","\n","name_no_extension_training = []\n","for n in list_source_training_temp:\n"," name_no_extension_training.append(os.path.splitext(n)[0])\n","\n","name_no_extension_validation = []\n","for n in list_source_validation_temp:\n"," name_no_extension_validation.append(os.path.splitext(n)[0])\n","\n","#Save the list of labels as text file\n","\n","with open('/content/training_files.txt', 'w') as f:\n"," for item in name_no_extension_training:\n"," print(item, end='\\n', file=f)\n","\n","with open('/content/validation_files.txt', 'w') as f:\n"," for item in name_no_extension_validation:\n"," print(item, end='\\n', file=f)\n","\n","\n","file_output_training = Training_target_temp+\"/output.json\"\n","file_output_validation = Validation_target_temp+\"/output.json\"\n","\n","\n","os.chdir(\"/content\")\n","!python voc2coco.py --ann_dir \"$Training_target_temp\" --output \"$file_output_training\" --ann_ids \"/content/training_files.txt\" --labels \"/content/labels.txt\" --ext xml\n","!python voc2coco.py --ann_dir \"$Validation_target_temp\" --output \"$file_output_validation\" --ann_ids \"/content/validation_files.txt\" --labels \"/content/labels.txt\" --ext xml\n","\n","\n","os.chdir(\"/\")\n","\n","#Here we load the dataset to detectron2\n","if cell_ran_training == 0:\n"," from detectron2.data.datasets import register_coco_instances\n"," register_coco_instances(\"my_dataset_train\", {}, Training_target_temp+\"/output.json\", Training_source_temp)\n"," register_coco_instances(\"my_dataset_val\", {}, Validation_target_temp+\"/output.json\", Validation_source_temp)\n","\n","\n","#visualize training data\n","my_dataset_train_metadata = MetadataCatalog.get(\"my_dataset_train\")\n","\n","dataset_dicts = DatasetCatalog.get(\"my_dataset_train\")\n","\n","import random\n","from detectron2.utils.visualizer import Visualizer\n","\n","for d in random.sample(dataset_dicts, 1):\n"," img = cv2.imread(d[\"file_name\"])\n"," visualizer = Visualizer(img[:, :, ::-1], metadata=my_dataset_train_metadata, instance_mode=ColorMode.SEGMENTATION, scale=0.8)\n"," vis = visualizer.draw_dataset_dict(d)\n"," cv2_imshow(vis.get_image()[:, :, ::-1])\n","\n","# failsafe\n","cell_ran_training = 1"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"STDOuNOFsTTJ"},"source":["## **3.2. Data augmentation** \n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"E4QW-tvYsWhX"},"source":["Data augmentation is currently enabled by default in this notebook. The option to disable data augmentation is not yet avaialble.\n"," "]},{"cell_type":"markdown","metadata":{"id":"W6pZg0KVnPzf"},"source":["\n","## **3.3. 
Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Detectron2 model**. \n","\n"," "]},{"cell_type":"code","metadata":{"id":"l-EDcv3Wyvqb","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Faster R-CNN\" #@param [\"Faster R-CNN\",\"RetinaNet\", \"Model_from_file\"]\n","\n","#pretrained_model_choice = \"Faster R-CNN\" #@param [\"Faster R-CNN\", \"RetinaNet\", \"RPN & Fast R-CNN\", \"Model_from_file\"]\n","\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = pretrained_model_path\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n"," if not os.path.exists(h5_file_path) and Use_pretrained_model:\n"," print('WARNING pretrained model does not exist')\n"," h5_file_path = \"COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml\"\n"," print('The Faster R-CNN model will be used.')\n"," \n"," if pretrained_model_choice == \"Faster R-CNN\":\n"," h5_file_path = \"COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml\"\n"," print('The Faster R-CNN model will be used.')\n"," \n"," if pretrained_model_choice == \"RetinaNet\":\n"," h5_file_path = \"COCO-Detection/retinanet_R_101_FPN_3x.yaml\"\n"," print('The RetinaNet model will be used.')\n","\n"," if pretrained_model_choice == \"RPN & Fast R-CNN\":\n"," h5_file_path = \"COCO-Detection/fast_rcnn_R_50_FPN_1x.yaml\"\n","\n","\n","if not Use_pretrained_model:\n"," h5_file_path = \"COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml\"\n"," print('The Faster R-CNN model will be used.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["#**4. Train the network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"DapHLZBVMNBZ"},"source":["\n","## **4.1. Start Trainning**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n"]},{"cell_type":"code","metadata":{"id":"Nft44VSLU8ZH","cellView":"form"},"source":["#@markdown ##Start training\n","\n","# Create the model folder\n","\n","if os.path.exists(full_model_path):\n"," shutil.rmtree(full_model_path)\n","os.makedirs(full_model_path)\n","\n","#Copy the label names in the model folder\n","shutil.copy(\"/content/labels.txt\", full_model_path+\"/\"+\"labels.txt\")\n","\n","#PDF export\n","#######################################\n","## MISSING \n","#######################################\n","#To be added\n","\n","start = time.time()\n","\n","#Load the config files\n","cfg = get_cfg()\n","\n","if pretrained_model_choice == \"Model_from_file\":\n"," cfg.merge_from_file(pretrained_model_path+\"/config.yaml\")\n","\n","if not pretrained_model_choice == \"Model_from_file\":\n"," cfg.merge_from_file(model_zoo.get_config_file(h5_file_path))\n","\n","cfg.DATASETS.TRAIN = (\"my_dataset_train\",)\n","cfg.DATASETS.TEST = (\"my_dataset_val\",)\n","cfg.OUTPUT_DIR= (full_model_path)\n","cfg.DATALOADER.NUM_WORKERS = 4\n","\n","if pretrained_model_choice == \"Model_from_file\":\n"," cfg.MODEL.WEIGHTS = pretrained_model_path+\"/model_final.pth\" # Let training initialize from model zoo\n","\n","if not pretrained_model_choice == \"Model_from_file\":\n"," cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(h5_file_path) # Let training initialize from model zoo\n","\n","cfg.SOLVER.IMS_PER_BATCH = int(batch_size)\n","cfg.SOLVER.BASE_LR = initial_learning_rate\n","\n","cfg.SOLVER.WARMUP_ITERS = 1000\n","cfg.SOLVER.MAX_ITER = int(number_of_iteration) #adjust up if val mAP is still rising, adjust down if overfit\n","cfg.SOLVER.STEPS = (1000, 1500)\n","cfg.SOLVER.GAMMA = 0.05\n","\n","cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512\n","\n","if pretrained_model_choice == \"Faster R-CNN\":\n"," cfg.MODEL.ROI_HEADS.NUM_CLASSES = (number_of_labels) \n","\n","if pretrained_model_choice == \"RetinaNet\":\n"," cfg.MODEL.RETINANET.NUM_CLASSES = (number_of_labels) \n","\n","cfg.TEST.EVAL_PERIOD = 500\n","trainer = CocoTrainer(cfg)\n","\n","trainer.resume_or_load(resume=False)\n","trainer.train()\n","\n","#Save the config file after trainning\n","config= cfg.dump() # print formatted configs\n","\n","file1 = open(full_model_path+\"/config.yaml\", 'w') \n"," \n","file1.writelines(config) \n","file1.close() #to change file access modes\n","\n","#Save the label file after trainning\n","\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Vd9igRYvSnTr"},"source":["## **4.2. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"sTMDT1u7rK9g"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
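(A note relating to the tip in section 4.1 above: the training cell writes `config.yaml` and `model_final.pth` into the model folder, so a run that hit the Colab time limit can be continued from those files in a later session. The following is only a rough sketch under the assumption that sections 1-3 have been re-run so that the training dataset is registered again; the folder path below is a placeholder.)

```python
# Hypothetical sketch: continue training from a model folder written by section 4.1.
# Assumes "my_dataset_train" / "my_dataset_val" were re-registered in this session (section 3.1).
from detectron2.config import get_cfg

previous_model_folder = "/content/gdrive/MyDrive/my_model"         # placeholder path

cfg = get_cfg()
cfg.merge_from_file(previous_model_folder + "/config.yaml")        # restore the saved configuration
cfg.MODEL.WEIGHTS = previous_model_folder + "/model_final.pth"     # start from the saved weights
cfg.SOLVER.MAX_ITER = 2000                                         # iterations for this new session

trainer = CocoTrainer(cfg)              # trainer class defined in section 2.3
trainer.resume_or_load(resume=False)    # load the weights above; the iteration counter restarts
trainer.train()
```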
Detectron 2 requires you to reload your training dataset in order to perform the quality control step.\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"OVxLyPyPiv85","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder as well as the location of your training dataset:\n","\n","#@markdown ####Path to trained model to be assessed: \n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ####Path to the image(s) used for training: \n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Here we load the list of classes stored in the model folder\n","list_of_labels_QC =[]\n","with open(full_QC_model_path+'labels.txt', newline='') as csvfile:\n"," reader = csv.reader(csvfile)\n"," for row in csv.reader(csvfile):\n"," list_of_labels_QC.append(row[0])\n","\n","#Here we create a list of color for later display\n","color_list = []\n","for i in range(len(list_of_labels_QC)):\n"," color = list(np.random.choice(range(256), size=3))\n"," color_list.append(color)\n","\n","#Save the list of labels as text file \n","if not (Use_the_current_trained_model):\n"," with open('/content/labels.txt', 'w') as f:\n"," for item in list_of_labels_QC:\n"," print(item, file=f)\n","\n"," # Here we split the data between training and validation\n"," # Here we count the number of files in the training target folder\n"," Filelist = os.listdir(Training_target)\n"," number_files = len(Filelist)\n"," percentage_validation= 10\n","\n"," File_for_validation = int((number_files)/percentage_validation)+1\n","\n"," #Here we split the training dataset between training and validation\n"," # Everything is copied in the /Content Folder\n","\n"," Training_source_temp = \"/content/training_source\"\n","\n"," if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n"," os.makedirs(Training_source_temp)\n","\n"," Training_target_temp = \"/content/training_target\"\n"," if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n"," os.makedirs(Training_target_temp)\n","\n"," Validation_source_temp = \"/content/validation_source\"\n","\n"," if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n"," os.makedirs(Validation_source_temp)\n","\n"," Validation_target_temp = \"/content/validation_target\"\n"," if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n"," os.makedirs(Validation_target_temp)\n","\n"," list_source = os.listdir(os.path.join(Training_source))\n"," list_target = 
os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n"," \n"," for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n"," for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n"," list_source_temp = os.listdir(os.path.join(Training_source_temp))\n"," list_target_temp = os.listdir(os.path.join(Training_target_temp))\n","\n","\n","#Here we move images to be used for validation\n"," for i in range(File_for_validation):\n","\n"," name = list_source_temp[i]\n"," shutil.move(Training_source_temp+\"/\"+name, Validation_source_temp+\"/\"+name)\n","\n"," shortname_no_extension = name[:-4]\n","\n"," shutil.move(Training_target_temp+\"/\"+shortname_no_extension+\".xml\", Validation_target_temp+\"/\"+shortname_no_extension+\".xml\")\n","\n","\n","#First we need to create list of labels to generate the json dictionaries\n","\n"," list_source_training_temp = os.listdir(os.path.join(Training_source_temp))\n"," list_source_validation_temp = os.listdir(os.path.join(Validation_source_temp))\n","\n"," name_no_extension_training = []\n"," for n in list_source_training_temp:\n"," name_no_extension_training.append(os.path.splitext(n)[0])\n","\n"," name_no_extension_validation = []\n"," for n in list_source_validation_temp:\n"," name_no_extension_validation.append(os.path.splitext(n)[0])\n","\n","#Save the list of labels as text file\n","\n"," with open('/content/training_files.txt', 'w') as f:\n"," for item in name_no_extension_training:\n"," print(item, end='\\n', file=f)\n","\n"," with open('/content/validation_files.txt', 'w') as f:\n"," for item in name_no_extension_validation:\n"," print(item, end='\\n', file=f)\n","\n"," file_output_training = Training_target_temp+\"/output.json\"\n"," file_output_validation = Validation_target_temp+\"/output.json\"\n","\n"," os.chdir(\"/content\")\n"," !python voc2coco.py --ann_dir \"$Training_target_temp\" --output \"$file_output_training\" --ann_ids \"/content/training_files.txt\" --labels \"/content/labels.txt\" --ext xml\n"," !python voc2coco.py --ann_dir \"$Validation_target_temp\" --output \"$file_output_validation\" --ann_ids \"/content/validation_files.txt\" --labels \"/content/labels.txt\" --ext xml\n","\n"," os.chdir(\"/\")\n","\n","#Here we load the dataset to detectron2\n"," if cell_ran_QC_training_dataset == 0:\n"," from detectron2.data.datasets import register_coco_instances\n"," register_coco_instances(\"my_dataset_train\", {}, Training_target_temp+\"/output.json\", Training_source_temp)\n"," register_coco_instances(\"my_dataset_val\", {}, Validation_target_temp+\"/output.json\", Validation_source_temp)\n"," \n","#Failsafe for later\n","cell_ran_QC_training_dataset = 1"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WZDvRjLZu-Lm"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by studying if your model is slowly improving over time. 
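In addition to TensorBoard (loaded in the next cell), the same training metrics can be read directly from the `metrics.json` file that Detectron2's default trainer writes into the model folder during training. A minimal sketch, assuming that file is present in `full_QC_model_path`:

```python
# Sketch: plot the total training loss logged by Detectron2 in metrics.json.
import json
import matplotlib.pyplot as plt

history = []
with open(full_QC_model_path + "metrics.json") as f:      # one JSON object per line
    for line in f:
        entry = json.loads(line)
        if "total_loss" in entry:
            history.append((entry["iteration"], entry["total_loss"]))

if history:
    iterations, losses = zip(*history)
    plt.plot(iterations, losses)
    plt.xlabel("Iteration")
    plt.ylabel("Total loss")
    plt.title("Training loss over iterations")
    plt.show()
```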
The following cell will allow you to load Tensorboard and investigate how several metric evolved over time (iterations).\n","\n"," "]},{"cell_type":"code","metadata":{"cellView":"form","id":"cap-cHIfNZnm"},"source":["#@markdown ##Play the cell to load tensorboard\n","%load_ext tensorboard\n","%tensorboard --logdir \"$full_QC_model_path\""],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lreUY7-SsGkI"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will compare the predictions generated by your model against ground-truth. Additionally, the below cell will show the mAP value of the model on the QC data If you want to read in more detail about this score, we recommend [this brief explanation](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173).\n","\n"," The images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" should contain images (e.g. as .png) and annotations (.xml files)!\n","\n","\n","**mAP score:** This refers to the mean average precision of the model on the given dataset. This value gives an indication how precise the predictions of the classes on this dataset are when compared to the ground-truth. Values closer to 1 indicate a good fit.\n"]},{"cell_type":"code","metadata":{"id":"kjbHJHbtsg2R","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","if cell_ran_QC_QC_dataset == 0:\n","#Save the list of labels as text file \n"," with open('/content/labels_QC.txt', 'w') as f:\n"," for item in list_of_labels_QC:\n"," print(item, file=f)\n","\n","#Here we create temp folder for the QC\n","\n"," QC_source_temp = \"/content/QC_source\"\n","\n"," if os.path.exists(QC_source_temp):\n"," shutil.rmtree(QC_source_temp)\n"," os.makedirs(QC_source_temp)\n","\n"," QC_target_temp = \"/content/QC_target\"\n"," if os.path.exists(QC_target_temp):\n"," shutil.rmtree(QC_target_temp)\n"," os.makedirs(QC_target_temp)\n","\n","# Create a quality control/Prediction Folder\n"," if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","#Here we move the QC files to the temp\n","\n"," for f in os.listdir(os.path.join(Source_QC_folder)):\n"," shutil.copy(Source_QC_folder+\"/\"+f, QC_source_temp+\"/\"+f)\n","\n"," for p in os.listdir(os.path.join(Target_QC_folder)):\n"," shutil.copy(Target_QC_folder+\"/\"+p, QC_target_temp+\"/\"+p)\n","\n","#Here we convert the XML files into JSON\n","#Save the list of files\n","\n"," list_source_QC_temp = os.listdir(os.path.join(QC_source_temp))\n","\n"," name_no_extension_QC = []\n"," for n in list_source_QC_temp:\n"," name_no_extension_QC.append(os.path.splitext(n)[0])\n","\n"," with open('/content/QC_files.txt', 'w') as f:\n"," for item in name_no_extension_QC:\n"," print(item, end='\\n', file=f)\n","\n","#Convert XML into JSON\n"," file_output_QC = QC_target_temp+\"/output.json\"\n","\n"," os.chdir(\"/content\")\n"," !python voc2coco.py --ann_dir \"$QC_target_temp\" --output \"$file_output_QC\" --ann_ids \"/content/QC_files.txt\" --labels \"/content/labels.txt\" --ext xml\n","\n"," os.chdir(\"/\")\n","\n","\n","#Here we register the QC dataset\n"," 
register_coco_instances(\"my_dataset_QC\", {}, QC_target_temp+\"/output.json\", QC_source_temp)\n"," cell_ran_QC_QC_dataset = 1\n","\n","\n","#Load the model to use\n","cfg = get_cfg()\n","cfg.merge_from_file(full_QC_model_path+\"config.yaml\")\n","cfg.MODEL.WEIGHTS = os.path.join(full_QC_model_path, \"model_final.pth\")\n","cfg.DATASETS.TEST = (\"my_dataset_QC\", )\n","cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5\n","\n","#Metadata\n","test_metadata = MetadataCatalog.get(\"my_dataset_QC\")\n","test_metadata.set(thing_color = color_list)\n","\n","# For the evaluation we need to load the trainer\n","trainer = CocoTrainer(cfg)\n","trainer.resume_or_load(resume=True)\n","\n","# Here we need to load the predictor\n","\n","predictor = DefaultPredictor(cfg)\n","evaluator = COCOEvaluator(\"my_dataset_QC\", cfg, False, output_dir=QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","val_loader = build_detection_test_loader(cfg, \"my_dataset_QC\")\n","inference_on_dataset(trainer.model, val_loader, evaluator)\n","\n","\n","print(\"A prediction is displayed\")\n","\n","dataset_QC_dicts = DatasetCatalog.get(\"my_dataset_QC\")\n","\n","for d in random.sample(dataset_QC_dicts, 1):\n"," print(\"Ground Truth\")\n"," img = cv2.imread(d[\"file_name\"])\n"," visualizer = Visualizer(img[:, :, ::-1], metadata=test_metadata, instance_mode=ColorMode.SEGMENTATION, scale=0.5)\n"," vis = visualizer.draw_dataset_dict(d)\n"," cv2_imshow(vis.get_image()[:, :, ::-1])\n","\n"," print(\"A prediction is displayed\")\n"," im = cv2.imread(d[\"file_name\"])\n"," outputs = predictor(im)\n"," v = Visualizer(im[:, :, ::-1],\n"," metadata=test_metadata,\n"," instance_mode=ColorMode.SEGMENTATION, \n"," scale=0.5\n"," )\n"," out = v.draw_instance_predictions(outputs[\"instances\"].to(\"cpu\"))\n"," cv2_imshow(out.get_image()[:, :, ::-1])\n","\n","cell_ran_QC_QC_dataset = 1"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DWAhOBc7gpzN"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"KAILvLGFS2-1"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"lp-cx8TDIGI-","cellView":"form"},"source":["\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder as well as the location of your training dataset:\n","\n","#@markdown ####Path to trained model to be assessed: \n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we will load the label file\n","\n","list_of_labels_predictions =[]\n","with open(full_Prediction_model_path+'labels.txt', newline='') as csvfile:\n"," reader = csv.reader(csvfile)\n"," for row in csv.reader(csvfile):\n"," list_of_labels_predictions.append(row[0])\n","\n","#Here we create a list of color\n","color_list = []\n","for i in range(len(list_of_labels_predictions)):\n"," color = list(np.random.choice(range(256), size=3))\n"," color_list.append(color)\n","\n","#Activate the pretrained model. 
\n","# Create config\n","cfg = get_cfg()\n","cfg.merge_from_file(full_Prediction_model_path+\"config.yaml\")\n","cfg.MODEL.WEIGHTS = os.path.join(full_Prediction_model_path, \"model_final.pth\")\n","cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model\n","\n","# Create predictor\n","predictor = DefaultPredictor(cfg)\n","\n","#Load the metadata from the prediction file\n","prediction_metadata = Metadata()\n","prediction_metadata.set(thing_classes = list_of_labels_predictions)\n","prediction_metadata.set(thing_color = color_list)\n","\n","start = datetime.now()\n","\n","validation_folder = Path(Data_folder)\n","\n","for i, file in enumerate(validation_folder.glob(\"*.png\")):\n"," # this loop opens the .png files from the val-folder, creates a dict with the file\n"," # information, plots visualizations and saves the result as .pkl files.\n"," file = str(file)\n"," file_name = file.split(\"/\")[-1]\n"," im = cv2.imread(file)\n","\n"," #Prediction are done here\n"," outputs = predictor(im)\n","\n"," #here we extract the results into numpy arrays\n","\n"," Classes_predictions = outputs[\"instances\"].pred_classes.cpu().data.numpy()\n","\n"," boxes_predictions = outputs[\"instances\"].pred_boxes.tensor.cpu().numpy()\n"," Score_predictions = outputs[\"instances\"].scores.cpu().data.numpy()\n"," \n"," #here we save the results into a csv file\n"," prediction_csv = Result_folder+\"/\"+file_name+\"_predictions.csv\"\n","\n"," with open(prediction_csv, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['x1','y1','x2','y2','box width','box height', 'class', 'score' ]) \n","\n"," for i in range(len(boxes_predictions)):\n","\n"," x1 = boxes_predictions[i][0]\n"," y1 = boxes_predictions[i][1]\n"," x2 = boxes_predictions[i][2]\n"," y2 = boxes_predictions[i][3]\n"," box_width = x2 - x1\n"," box_height = y2 -y1\n","\n"," writer.writerow([str(x1), str(y1), str(x2), str(y2), str(box_width), str(box_height), str(list_of_labels_predictions[Classes_predictions[i]]), Score_predictions[i]])\n","\n","\n","# The last example is displayed \n","v = Visualizer(im, metadata=prediction_metadata, instance_mode=ColorMode.SEGMENTATION, scale=1)\n","v = v.draw_instance_predictions(outputs[\"instances\"].to(\"cpu\")) \n","plt.figure(figsize=(20,20))\n","plt.imshow(v.get_image()[:, :, ::-1])\n","plt.axis('off');\n","plt.savefig(Result_folder+\"/\"+file_name)\n"," \n","print(\"Time needed for inferencing:\", datetime.now() - start)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wgO7Ok1PBFQj"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS"},"source":["#**Thank you for using Detectron2 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Detectron2_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[]},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"id":"JhcOyBQjR54F"},"source":["#**This notebook is in beta**\n","Expect some instabilities and bugs.\n","\n","**Currently missing features include:**\n","\n","- Augmentation cannot be disabled\n","- Exported results include only a simple CSV file. More options will be included in the next releases\n","- Training and QC reports are not generated\n"]},{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **Detectron2 (2D)**\n","\n"," Detectron2 is a deep-learning method designed to perform object detection and classification of objects in images. Detectron2 is Facebook AI Research's next generation software system that implements state-of-the-art object detection algorithms. It is a ground-up rewrite of the previous version, Detectron, and it originates from maskrcnn-benchmark. More information on Detectron2 can be found on the Detectron2 github pages (/~https://github.com/facebookresearch/detectron2).\n","\n","\n","\n","**This particular notebook enables object detection and classification on 2D images given ground truth bounding boxes. If you are interested in image segmentation, you should use our U-net or Stardist notebooks instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"NDICs5NxYEWP"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","id":"R575GX8cX2aP"},"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","\n","#Apache License\n","#Version 2.0, January 2004\n","#http://www.apache.org/licenses/\n","\n","#TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n","\n","#1. 
Definitions.\n","\n","#\"License\" shall mean the terms and conditions for use, reproduction,\n","#and distribution as defined by Sections 1 through 9 of this document.\n","\n","#\"Licensor\" shall mean the copyright owner or entity authorized by\n","#the copyright owner that is granting the License.\n","\n","#\"Legal Entity\" shall mean the union of the acting entity and all\n","#other entities that control, are controlled by, or are under common\n","#control with that entity. For the purposes of this definition,\n","#\"control\" means (i) the power, direct or indirect, to cause the\n","#direction or management of such entity, whether by contract or\n","#otherwise, or (ii) ownership of fifty percent (50%) or more of the\n","#outstanding shares, or (iii) beneficial ownership of such entity.\n","\n","#\"You\" (or \"Your\") shall mean an individual or Legal Entity\n","#exercising permissions granted by this License.\n","\n","#\"Source\" form shall mean the preferred form for making modifications,\n","#including but not limited to software source code, documentation\n","#source, and configuration files.\n","\n","#\"Object\" form shall mean any form resulting from mechanical\n","#transformation or translation of a Source form, including but\n","#not limited to compiled object code, generated documentation,\n","#and conversions to other media types.\n","\n","#\"Work\" shall mean the work of authorship, whether in Source or\n","#Object form, made available under the License, as indicated by a\n","#copyright notice that is included in or attached to the work\n","#(an example is provided in the Appendix below).\n","\n","#\"Derivative Works\" shall mean any work, whether in Source or Object\n","#form, that is based on (or derived from) the Work and for which the\n","#editorial revisions, annotations, elaborations, or other modifications\n","#represent, as a whole, an original work of authorship. For the purposes\n","#of this License, Derivative Works shall not include works that remain\n","#separable from, or merely link (or bind by name) to the interfaces of,\n","#the Work and Derivative Works thereof.\n","\n","#\"Contribution\" shall mean any work of authorship, including\n","#the original version of the Work and any modifications or additions\n","#to that Work or Derivative Works thereof, that is intentionally\n","#submitted to Licensor for inclusion in the Work by the copyright owner\n","#or by an individual or Legal Entity authorized to submit on behalf of\n","#the copyright owner. For the purposes of this definition, \"submitted\"\n","#means any form of electronic, verbal, or written communication sent\n","#to the Licensor or its representatives, including but not limited to\n","#communication on electronic mailing lists, source code control systems,\n","#and issue tracking systems that are managed by, or on behalf of, the\n","#Licensor for the purpose of discussing and improving the Work, but\n","#excluding communication that is conspicuously marked or otherwise\n","#designated in writing by the copyright owner as \"Not a Contribution.\"\n","\n","#\"Contributor\" shall mean Licensor and any individual or Legal Entity\n","#on behalf of whom a Contribution has been received by Licensor and\n","#subsequently incorporated within the Work.\n","\n","#2. Grant of Copyright License. 
Subject to the terms and conditions of\n","#this License, each Contributor hereby grants to You a perpetual,\n","#worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n","#copyright license to reproduce, prepare Derivative Works of,\n","#publicly display, publicly perform, sublicense, and distribute the\n","#Work and such Derivative Works in Source or Object form.\n","\n","#3. Grant of Patent License. Subject to the terms and conditions of\n","#this License, each Contributor hereby grants to You a perpetual,\n","#worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n","#(except as stated in this section) patent license to make, have made,\n","#use, offer to sell, sell, import, and otherwise transfer the Work,\n","#where such license applies only to those patent claims licensable\n","#by such Contributor that are necessarily infringed by their\n","#Contribution(s) alone or by combination of their Contribution(s)\n","#with the Work to which such Contribution(s) was submitted. If You\n","#institute patent litigation against any entity (including a\n","#cross-claim or counterclaim in a lawsuit) alleging that the Work\n","#or a Contribution incorporated within the Work constitutes direct\n","#or contributory patent infringement, then any patent licenses\n","#granted to You under this License for that Work shall terminate\n","#as of the date such litigation is filed.\n","\n","#4. Redistribution. You may reproduce and distribute copies of the\n","#Work or Derivative Works thereof in any medium, with or without\n","#modifications, and in Source or Object form, provided that You\n","#meet the following conditions:\n","\n","#(a) You must give any other recipients of the Work or\n","#Derivative Works a copy of this License; and\n","\n","#(b) You must cause any modified files to carry prominent notices\n","#stating that You changed the files; and\n","\n","#(c) You must retain, in the Source form of any Derivative Works\n","#that You distribute, all copyright, patent, trademark, and\n","#attribution notices from the Source form of the Work,\n","#excluding those notices that do not pertain to any part of\n","#the Derivative Works; and\n","\n","#(d) If the Work includes a \"NOTICE\" text file as part of its\n","#distribution, then any Derivative Works that You distribute must\n","#include a readable copy of the attribution notices contained\n","#within such NOTICE file, excluding those notices that do not\n","#pertain to any part of the Derivative Works, in at least one\n","#of the following places: within a NOTICE text file distributed\n","#as part of the Derivative Works; within the Source form or\n","#documentation, if provided along with the Derivative Works; or,\n","#within a display generated by the Derivative Works, if and\n","#wherever such third-party notices normally appear. The contents\n","#of the NOTICE file are for informational purposes only and\n","#do not modify the License. 
You may add Your own attribution\n","#notices within Derivative Works that You distribute, alongside\n","#or as an addendum to the NOTICE text from the Work, provided\n","#that such additional attribution notices cannot be construed\n","#as modifying the License.\n","\n","#You may add Your own copyright statement to Your modifications and\n","#may provide additional or different license terms and conditions\n","#for use, reproduction, or distribution of Your modifications, or\n","#for any such Derivative Works as a whole, provided Your use,\n","#reproduction, and distribution of the Work otherwise complies with\n","#the conditions stated in this License.\n","\n","#5. Submission of Contributions. Unless You explicitly state otherwise,\n","#any Contribution intentionally submitted for inclusion in the Work\n","#by You to the Licensor shall be under the terms and conditions of\n","#this License, without any additional terms or conditions.\n","#Notwithstanding the above, nothing herein shall supersede or modify\n","#the terms of any separate license agreement you may have executed\n","#with Licensor regarding such Contributions.\n","\n","#6. Trademarks. This License does not grant permission to use the trade\n","#names, trademarks, service marks, or product names of the Licensor,\n","#except as required for reasonable and customary use in describing the\n","#origin of the Work and reproducing the content of the NOTICE file.\n","\n","#7. Disclaimer of Warranty. Unless required by applicable law or\n","#agreed to in writing, Licensor provides the Work (and each\n","#Contributor provides its Contributions) on an \"AS IS\" BASIS,\n","#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n","#implied, including, without limitation, any warranties or conditions\n","#of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n","#PARTICULAR PURPOSE. You are solely responsible for determining the\n","#appropriateness of using or redistributing the Work and assume any\n","#risks associated with Your exercise of permissions under this License.\n","\n","#8. Limitation of Liability. In no event and under no legal theory,\n","#whether in tort (including negligence), contract, or otherwise,\n","#unless required by applicable law (such as deliberate and grossly\n","#negligent acts) or agreed to in writing, shall any Contributor be\n","#liable to You for damages, including any direct, indirect, special,\n","#incidental, or consequential damages of any character arising as a\n","#result of this License or out of the use or inability to use the\n","#Work (including but not limited to damages for loss of goodwill,\n","#work stoppage, computer failure or malfunction, or any and all\n","#other commercial damages or losses), even if such Contributor\n","#has been advised of the possibility of such damages.\n","\n","#9. Accepting Warranty or Additional Liability. While redistributing\n","#the Work or Derivative Works thereof, You may choose to offer,\n","#and charge a fee for, acceptance of support, warranty, indemnity,\n","#or other liability obligations and/or rights consistent with this\n","#License. 
However, in accepting such obligations, You may act only\n","#on Your own behalf and on Your sole responsibility, not on behalf\n","#of any other Contributor, and only if You agree to indemnify,\n","#defend, and hold each Contributor harmless for any liability\n","#incurred by, or claims asserted against, such Contributor by reason\n","#of your accepting any such warranty or additional liability.\n","\n","#END OF TERMS AND CONDITIONS\n","\n","#APPENDIX: How to apply the Apache License to your work.\n","\n","#To apply the Apache License to your work, attach the following\n","#boilerplate notice, with the fields enclosed by brackets \"[]\"\n","#replaced with your own identifying information. (Don't include\n","#the brackets!) The text should be enclosed in the appropriate\n","#comment syntax for the file format. We also recommend that a\n","#file or class name and description of purpose be included on the\n","#same \"printed page\" as the copyright notice for easier\n","#identification within third-party archives.\n","\n","#Copyright [yyyy] [name of copyright owner]\n","\n","\n","#Licensed under the Apache License, Version 2.0 (the \"License\");\n","#you may not use this file except in compliance with the License.\n","#You may obtain a copy of the License at\n","\n","#http://www.apache.org/licenses/LICENSE-2.0\n","\n","#Unless required by applicable law or agreed to in writing, software\n","#distributed under the License is distributed on an \"AS IS\" BASIS,\n","#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n","#See the License for the specific language governing permissions and\n","#limitations under the License."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. 
\n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n"," Preparing the dataset carefully is essential to make this Detectron2 notebook work. This model requires as input a set of images and as target a list of annotation files in Pascal VOC format. The annotation files should have the exact same name as the input files, except with an .xml instead of the .jpg extension. The annotation files contain the class labels and all bounding boxes for the objects for each image in your dataset. Most datasets will give the option of saving the annotations in this format or using software for hand-annotations will automatically save the annotations in this format. \n","\n"," If you want to assemble your own dataset we recommend using the open source https://www.makesense.ai/ resource. You can follow our instructions on how to label your dataset with this tool on our [wiki](/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki/Object-Detection-(YOLOv2)).\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. 
The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .png files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Input images (Training_source)\n"," - img_1.png, img_2.png, ...\n"," - High SNR images (Training_source_annotations)\n"," - img_1.xml, img_2.xml, ...\n"," - **Quality control dataset**\n"," - Input images\n"," - img_1.png, img_2.png\n"," - High SNR images\n"," - img_1.xml, img_2.xml\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Install Detectron2 and dependencies**\n","---"]},{"cell_type":"markdown","metadata":{"id":"yg1vZe88JEyk"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"Tw1Usk1iPvRQ","cellView":"form"},"source":[" #@markdown ##Install dependencies and Detectron2\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory\n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","# install dependencies\n","#!pip install -U torch torchvision cython\n","!pip install -U 'git+/~https://github.com/facebookresearch/fvcore.git' 
'git+/~https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'\n","import torch, torchvision\n","import os\n","import pandas as pd\n","torch.__version__\n"," \n","!git clone /~https://github.com/facebookresearch/detectron2 detectron2_repo\n","!pip install -e detectron2_repo\n","\n","!pip install wget\n","\n","#Force session restart\n","exit(0)\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xhWNIu6cf5G8"},"source":["## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
\n"]},{"cell_type":"markdown","metadata":{"id":"5nXTBntzKRWu"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'Detectron 2D'\n","\n","\n","#@markdown ##Play this cell to load the required dependencies\n","import wget\n","# Some basic setup: \n","import detectron2\n","from detectron2.utils.logger import setup_logger\n","setup_logger()\n"," \n","# import some common libraries\n","import numpy as np\n","import os, json, cv2, random\n","from google.colab.patches import cv2_imshow\n"," \n","import yaml\n"," \n","#Download the script to convert XML into COCO\n"," \n","wget.download(\"/~https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Tools/voc2coco.py\", \"/content\")\n"," \n"," \n","# import some common detectron2 utilities\n","from detectron2 import model_zoo\n","from detectron2.engine import DefaultPredictor\n","from detectron2.config import get_cfg\n","from detectron2.utils.visualizer import Visualizer\n","from detectron2.data import MetadataCatalog, DatasetCatalog\n","from detectron2.utils.visualizer import ColorMode\n","\n","from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader\n","from datetime import datetime\n","from detectron2.data.catalog import Metadata\n","\n","from detectron2.config import get_cfg\n","from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader\n","from detectron2.evaluation import COCOEvaluator, inference_on_dataset\n","from detectron2.engine import DefaultTrainer\n","from detectron2.data.datasets import register_coco_instances\n","from detectron2.utils.visualizer import ColorMode\n","import glob\n","from detectron2.checkpoint import Checkpointer\n","from detectron2.config import get_cfg\n","import os\n"," \n"," \n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n"," \n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n"," \n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n"," \n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n"," \n"," \n","from detectron2.engine import DefaultTrainer\n","from detectron2.evaluation import COCOEvaluator\n"," \n","class CocoTrainer(DefaultTrainer):\n"," \n"," @classmethod\n"," def build_evaluator(cls, cfg, dataset_name, output_folder=None):\n"," \n"," if output_folder is None:\n"," os.makedirs(\"coco_eval\", exist_ok=True)\n"," output_folder = \"coco_eval\"\n"," \n"," return COCOEvaluator(dataset_name, cfg, False, output_folder)\n"," \n"," \n"," \n","print(\"Librairies loaded\")\n"," \n","# Check if this is the latest version of the notebook\n","All_notebook_versions = 
pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n"," \n","#Failsafes\n","cell_ran_prediction = 0\n","cell_ran_training = 0\n","cell_ran_QC_training_dataset = 0\n","cell_ran_QC_QC_dataset = 0"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"cbTknRcviyT7"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"h5i5CS2bSmZr"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","#%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n3B3meGTbYVi"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n"," \n","#@markdown * Click on the URL. \n"," \n","#@markdown * Sign in your Google Account. \n"," \n","#@markdown * Copy the authorization code. \n"," \n","#@markdown * Enter the authorization code. \n"," \n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n"," \n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xm5YEhKq-Hse"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
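Once the Drive is connected, its content becomes visible under the mount point used in the cell above; the short sketch below (an addition, not part of the original notebook) simply confirms the mount before the paths in section 3 are filled in:

```python
# Sketch: confirm that Google Drive was mounted at the mount point used by drive.mount() above.
import os

mount_point = "/content/gdrive"
if os.path.isdir(mount_point):
    print("Drive mounted. Top-level content:", os.listdir(mount_point))
else:
    print("Drive does not appear to be mounted - please re-run the cell above.")
```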
"]},{"cell_type":"markdown","metadata":{"id":"iwjra6kMKmUA"},"source":["# **3. Select your parameters and paths**\n"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and the annotation data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`labels`:** Input the name of the differentes labels used to annotate your dataset (separated by a comma).\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_iteration`:** Input how many iterations to use to train the network. Initial results can be observed using 1000 iterations but consider using 5000 or more iterations to train your models. **Default value: 2000**\n"," \n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0001**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# create DataGenerator-object.\n","\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Labels\n","#@markdown Input the name of the differentes labels present in your training dataset separated by a comma\n","labels = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = model_path+'/'+model_name+'/'\n","\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of iterations:\n","number_of_iteration = 2000#@param {type:\"number\"}\n","\n","\n","#Here we store the informations related to our labels\n","\n","list_of_labels = labels.split(\", \")\n","with open('/content/labels.txt', 'w') as f:\n"," for item in list_of_labels:\n"," print(item, file=f)\n","\n","number_of_labels = len(list_of_labels)\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," percentage_validation = 10\n"," initial_learning_rate = 0.001\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = True\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","# Here we split the data between training and validation\n","# Here we count the number of files in the training target folder\n","Filelist = os.listdir(Training_target)\n","number_files = len(Filelist)\n","\n","File_for_validation = int((number_files)/percentage_validation)+1\n","\n","#Here we split the training dataset between training and validation\n","# Everything is copied in the /Content Folder\n","\n","Training_source_temp = \"/content/training_source\"\n","\n","if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n","os.makedirs(Training_source_temp)\n","\n","Training_target_temp = \"/content/training_target\"\n","if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n","os.makedirs(Training_target_temp)\n","\n","Validation_source_temp = \"/content/validation_source\"\n","\n","if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n","os.makedirs(Validation_source_temp)\n","\n","Validation_target_temp = \"/content/validation_target\"\n","if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n","os.makedirs(Validation_target_temp)\n","\n","list_source = os.listdir(os.path.join(Training_source))\n","list_target = os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n","\n"," \n","for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n","for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n","\n","list_source_temp = 
os.listdir(os.path.join(Training_source_temp))\n","list_target_temp = os.listdir(os.path.join(Training_target_temp))\n","\n","\n","#Here we move images to be used for validation\n","for i in range(File_for_validation):\n","\n"," name = list_source_temp[i]\n"," shutil.move(Training_source_temp+\"/\"+name, Validation_source_temp+\"/\"+name)\n","\n"," shortname_no_extension = name[:-4]\n","\n"," shutil.move(Training_target_temp+\"/\"+shortname_no_extension+\".xml\", Validation_target_temp+\"/\"+shortname_no_extension+\".xml\")\n","\n","# Here we convert the XML files into COCO format to be loaded in detectron2\n","\n","#First we need to create list of labels to generate the json dictionaries\n","\n","list_source_training_temp = os.listdir(os.path.join(Training_source_temp))\n","list_source_validation_temp = os.listdir(os.path.join(Validation_source_temp))\n","\n","\n","name_no_extension_training = []\n","for n in list_source_training_temp:\n"," name_no_extension_training.append(os.path.splitext(n)[0])\n","\n","name_no_extension_validation = []\n","for n in list_source_validation_temp:\n"," name_no_extension_validation.append(os.path.splitext(n)[0])\n","\n","#Save the list of labels as text file\n","\n","with open('/content/training_files.txt', 'w') as f:\n"," for item in name_no_extension_training:\n"," print(item, end='\\n', file=f)\n","\n","with open('/content/validation_files.txt', 'w') as f:\n"," for item in name_no_extension_validation:\n"," print(item, end='\\n', file=f)\n","\n","\n","file_output_training = Training_target_temp+\"/output.json\"\n","file_output_validation = Validation_target_temp+\"/output.json\"\n","\n","\n","os.chdir(\"/content\")\n","!python voc2coco.py --ann_dir \"$Training_target_temp\" --output \"$file_output_training\" --ann_ids \"/content/training_files.txt\" --labels \"/content/labels.txt\" --ext xml\n","!python voc2coco.py --ann_dir \"$Validation_target_temp\" --output \"$file_output_validation\" --ann_ids \"/content/validation_files.txt\" --labels \"/content/labels.txt\" --ext xml\n","\n","\n","os.chdir(\"/\")\n","\n","#Here we load the dataset to detectron2\n","if cell_ran_training == 0:\n"," from detectron2.data.datasets import register_coco_instances\n"," register_coco_instances(\"my_dataset_train\", {}, Training_target_temp+\"/output.json\", Training_source_temp)\n"," register_coco_instances(\"my_dataset_val\", {}, Validation_target_temp+\"/output.json\", Validation_source_temp)\n","\n","\n","#visualize training data\n","my_dataset_train_metadata = MetadataCatalog.get(\"my_dataset_train\")\n","\n","dataset_dicts = DatasetCatalog.get(\"my_dataset_train\")\n","\n","import random\n","from detectron2.utils.visualizer import Visualizer\n","\n","for d in random.sample(dataset_dicts, 1):\n"," img = cv2.imread(d[\"file_name\"])\n"," visualizer = Visualizer(img[:, :, ::-1], metadata=my_dataset_train_metadata, instance_mode=ColorMode.SEGMENTATION, scale=0.8)\n"," vis = visualizer.draw_dataset_dict(d)\n"," cv2_imshow(vis.get_image()[:, :, ::-1])\n","\n","# failsafe\n","cell_ran_training = 1"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"STDOuNOFsTTJ"},"source":["## **3.2. Data augmentation** \n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"E4QW-tvYsWhX"},"source":["Data augmentation is currently enabled by default in this notebook. The option to disable data augmentation is not yet avaialble.\n"," "]},{"cell_type":"markdown","metadata":{"id":"W6pZg0KVnPzf"},"source":["\n","## **3.3. 
Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Detectron2 model**. \n","\n"," "]},{"cell_type":"code","metadata":{"id":"l-EDcv3Wyvqb","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Faster R-CNN\" #@param [\"Faster R-CNN\",\"RetinaNet\", \"Model_from_file\"]\n","\n","#pretrained_model_choice = \"Faster R-CNN\" #@param [\"Faster R-CNN\", \"RetinaNet\", \"RPN & Fast R-CNN\", \"Model_from_file\"]\n","\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = pretrained_model_path\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n"," if not os.path.exists(h5_file_path) and Use_pretrained_model:\n"," print('WARNING pretrained model does not exist')\n"," h5_file_path = \"COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml\"\n"," print('The Faster R-CNN model will be used.')\n"," \n"," if pretrained_model_choice == \"Faster R-CNN\":\n"," h5_file_path = \"COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml\"\n"," print('The Faster R-CNN model will be used.')\n"," \n"," if pretrained_model_choice == \"RetinaNet\":\n"," h5_file_path = \"COCO-Detection/retinanet_R_101_FPN_3x.yaml\"\n"," print('The RetinaNet model will be used.')\n","\n"," if pretrained_model_choice == \"RPN & Fast R-CNN\":\n"," h5_file_path = \"COCO-Detection/fast_rcnn_R_50_FPN_1x.yaml\"\n","\n","\n","if not Use_pretrained_model:\n"," h5_file_path = \"COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml\"\n"," print('The Faster R-CNN model will be used.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["#**4. Train the network**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"DapHLZBVMNBZ"},"source":["\n","## **4.1. Start Trainning**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n"]},{"cell_type":"code","metadata":{"id":"Nft44VSLU8ZH","cellView":"form"},"source":["#@markdown ##Start training\n","\n","# Create the model folder\n","\n","if os.path.exists(full_model_path):\n"," shutil.rmtree(full_model_path)\n","os.makedirs(full_model_path)\n","\n","#Copy the label names in the model folder\n","shutil.copy(\"/content/labels.txt\", full_model_path+\"/\"+\"labels.txt\")\n","\n","#PDF export\n","#######################################\n","## MISSING \n","#######################################\n","#To be added\n","\n","start = time.time()\n","\n","#Load the config files\n","cfg = get_cfg()\n","\n","if pretrained_model_choice == \"Model_from_file\":\n"," cfg.merge_from_file(pretrained_model_path+\"/config.yaml\")\n","\n","if not pretrained_model_choice == \"Model_from_file\":\n"," cfg.merge_from_file(model_zoo.get_config_file(h5_file_path))\n","\n","cfg.DATASETS.TRAIN = (\"my_dataset_train\",)\n","cfg.DATASETS.TEST = (\"my_dataset_val\",)\n","cfg.OUTPUT_DIR= (full_model_path)\n","cfg.DATALOADER.NUM_WORKERS = 4\n","\n","if pretrained_model_choice == \"Model_from_file\":\n"," cfg.MODEL.WEIGHTS = pretrained_model_path+\"/model_final.pth\" # Let training initialize from model zoo\n","\n","if not pretrained_model_choice == \"Model_from_file\":\n"," cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(h5_file_path) # Let training initialize from model zoo\n","\n","cfg.SOLVER.IMS_PER_BATCH = int(batch_size)\n","cfg.SOLVER.BASE_LR = initial_learning_rate\n","\n","cfg.SOLVER.WARMUP_ITERS = 1000\n","cfg.SOLVER.MAX_ITER = int(number_of_iteration) #adjust up if val mAP is still rising, adjust down if overfit\n","cfg.SOLVER.STEPS = (1000, 1500)\n","cfg.SOLVER.GAMMA = 0.05\n","\n","cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512\n","\n","if pretrained_model_choice == \"Faster R-CNN\":\n"," cfg.MODEL.ROI_HEADS.NUM_CLASSES = (number_of_labels) \n","\n","if pretrained_model_choice == \"RetinaNet\":\n"," cfg.MODEL.RETINANET.NUM_CLASSES = (number_of_labels) \n","\n","cfg.TEST.EVAL_PERIOD = 500\n","trainer = CocoTrainer(cfg)\n","\n","trainer.resume_or_load(resume=False)\n","trainer.train()\n","\n","#Save the config file after trainning\n","config= cfg.dump() # print formatted configs\n","\n","file1 = open(full_model_path+\"/config.yaml\", 'w') \n"," \n","file1.writelines(config) \n","file1.close() #to change file access modes\n","\n","#Save the label file after trainning\n","\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Vd9igRYvSnTr"},"source":["## **4.2. Download your model(s) from Google Drive**\n","---\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"markdown","metadata":{"id":"sTMDT1u7rK9g"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
Detectron 2 requires you to reload your training dataset in order to perform the quality control step.\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"OVxLyPyPiv85","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder as well as the location of your training dataset:\n","\n","#@markdown ####Path to trained model to be assessed: \n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ####Path to the image(s) used for training: \n","Training_source = \"\" #@param {type:\"string\"}\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Here we load the list of classes stored in the model folder\n","list_of_labels_QC =[]\n","with open(full_QC_model_path+'labels.txt', newline='') as csvfile:\n"," reader = csv.reader(csvfile)\n"," for row in csv.reader(csvfile):\n"," list_of_labels_QC.append(row[0])\n","\n","#Here we create a list of color for later display\n","color_list = []\n","for i in range(len(list_of_labels_QC)):\n"," color = list(np.random.choice(range(256), size=3))\n"," color_list.append(color)\n","\n","#Save the list of labels as text file \n","if not (Use_the_current_trained_model):\n"," with open('/content/labels.txt', 'w') as f:\n"," for item in list_of_labels_QC:\n"," print(item, file=f)\n","\n"," # Here we split the data between training and validation\n"," # Here we count the number of files in the training target folder\n"," Filelist = os.listdir(Training_target)\n"," number_files = len(Filelist)\n"," percentage_validation= 10\n","\n"," File_for_validation = int((number_files)/percentage_validation)+1\n","\n"," #Here we split the training dataset between training and validation\n"," # Everything is copied in the /Content Folder\n","\n"," Training_source_temp = \"/content/training_source\"\n","\n"," if os.path.exists(Training_source_temp):\n"," shutil.rmtree(Training_source_temp)\n"," os.makedirs(Training_source_temp)\n","\n"," Training_target_temp = \"/content/training_target\"\n"," if os.path.exists(Training_target_temp):\n"," shutil.rmtree(Training_target_temp)\n"," os.makedirs(Training_target_temp)\n","\n"," Validation_source_temp = \"/content/validation_source\"\n","\n"," if os.path.exists(Validation_source_temp):\n"," shutil.rmtree(Validation_source_temp)\n"," os.makedirs(Validation_source_temp)\n","\n"," Validation_target_temp = \"/content/validation_target\"\n"," if os.path.exists(Validation_target_temp):\n"," shutil.rmtree(Validation_target_temp)\n"," os.makedirs(Validation_target_temp)\n","\n"," list_source = os.listdir(os.path.join(Training_source))\n"," list_target = 
os.listdir(os.path.join(Training_target))\n","\n","#Move files into the temporary source and target directories:\n"," \n"," for f in os.listdir(os.path.join(Training_source)):\n"," shutil.copy(Training_source+\"/\"+f, Training_source_temp+\"/\"+f)\n","\n"," for p in os.listdir(os.path.join(Training_target)):\n"," shutil.copy(Training_target+\"/\"+p, Training_target_temp+\"/\"+p)\n","\n"," list_source_temp = os.listdir(os.path.join(Training_source_temp))\n"," list_target_temp = os.listdir(os.path.join(Training_target_temp))\n","\n","\n","#Here we move images to be used for validation\n"," for i in range(File_for_validation):\n","\n"," name = list_source_temp[i]\n"," shutil.move(Training_source_temp+\"/\"+name, Validation_source_temp+\"/\"+name)\n","\n"," shortname_no_extension = name[:-4]\n","\n"," shutil.move(Training_target_temp+\"/\"+shortname_no_extension+\".xml\", Validation_target_temp+\"/\"+shortname_no_extension+\".xml\")\n","\n","\n","#First we need to create list of labels to generate the json dictionaries\n","\n"," list_source_training_temp = os.listdir(os.path.join(Training_source_temp))\n"," list_source_validation_temp = os.listdir(os.path.join(Validation_source_temp))\n","\n"," name_no_extension_training = []\n"," for n in list_source_training_temp:\n"," name_no_extension_training.append(os.path.splitext(n)[0])\n","\n"," name_no_extension_validation = []\n"," for n in list_source_validation_temp:\n"," name_no_extension_validation.append(os.path.splitext(n)[0])\n","\n","#Save the list of labels as text file\n","\n"," with open('/content/training_files.txt', 'w') as f:\n"," for item in name_no_extension_training:\n"," print(item, end='\\n', file=f)\n","\n"," with open('/content/validation_files.txt', 'w') as f:\n"," for item in name_no_extension_validation:\n"," print(item, end='\\n', file=f)\n","\n"," file_output_training = Training_target_temp+\"/output.json\"\n"," file_output_validation = Validation_target_temp+\"/output.json\"\n","\n"," os.chdir(\"/content\")\n"," !python voc2coco.py --ann_dir \"$Training_target_temp\" --output \"$file_output_training\" --ann_ids \"/content/training_files.txt\" --labels \"/content/labels.txt\" --ext xml\n"," !python voc2coco.py --ann_dir \"$Validation_target_temp\" --output \"$file_output_validation\" --ann_ids \"/content/validation_files.txt\" --labels \"/content/labels.txt\" --ext xml\n","\n"," os.chdir(\"/\")\n","\n","#Here we load the dataset to detectron2\n"," if cell_ran_QC_training_dataset == 0:\n"," from detectron2.data.datasets import register_coco_instances\n"," register_coco_instances(\"my_dataset_train\", {}, Training_target_temp+\"/output.json\", Training_source_temp)\n"," register_coco_instances(\"my_dataset_val\", {}, Validation_target_temp+\"/output.json\", Validation_source_temp)\n"," \n","#Failsafe for later\n","cell_ran_QC_training_dataset = 1"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WZDvRjLZu-Lm"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by studying if your model is slowly improving over time. 
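One straightforward way to do this is to read the `metrics.json` log that Detectron2's `DefaultTrainer` writes into the model folder during training. The sketch below (an addition, not part of the original notebook) plots the total loss from that file, assuming `full_QC_model_path` is set as in the cell above and that the log file exists:

```python
# Sketch: plot the total training loss recorded in Detectron2's metrics.json log.
# Assumes full_QC_model_path is defined (section 5 setup cell) and training wrote the log there.
import json
import matplotlib.pyplot as plt

iterations, total_loss = [], []
with open(full_QC_model_path + "metrics.json") as f:
    for line in f:
        entry = json.loads(line)
        if "total_loss" in entry:
            iterations.append(entry["iteration"])
            total_loss.append(entry["total_loss"])

plt.plot(iterations, total_loss)
plt.xlabel("Iteration")
plt.ylabel("Total loss")
plt.title("Training loss from metrics.json")
plt.show()
```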
The following cell will allow you to load Tensorboard and investigate how several metric evolved over time (iterations).\n","\n"," "]},{"cell_type":"code","metadata":{"cellView":"form","id":"cap-cHIfNZnm"},"source":["#@markdown ##Play the cell to load tensorboard\n","%load_ext tensorboard\n","%tensorboard --logdir \"$full_QC_model_path\""],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lreUY7-SsGkI"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will compare the predictions generated by your model against ground-truth. Additionally, the below cell will show the mAP value of the model on the QC data If you want to read in more detail about this score, we recommend [this brief explanation](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173).\n","\n"," The images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" should contain images (e.g. as .png) and annotations (.xml files)!\n","\n","\n","**mAP score:** This refers to the mean average precision of the model on the given dataset. This value gives an indication how precise the predictions of the classes on this dataset are when compared to the ground-truth. Values closer to 1 indicate a good fit.\n"]},{"cell_type":"code","metadata":{"id":"kjbHJHbtsg2R","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","if cell_ran_QC_QC_dataset == 0:\n","#Save the list of labels as text file \n"," with open('/content/labels_QC.txt', 'w') as f:\n"," for item in list_of_labels_QC:\n"," print(item, file=f)\n","\n","#Here we create temp folder for the QC\n","\n"," QC_source_temp = \"/content/QC_source\"\n","\n"," if os.path.exists(QC_source_temp):\n"," shutil.rmtree(QC_source_temp)\n"," os.makedirs(QC_source_temp)\n","\n"," QC_target_temp = \"/content/QC_target\"\n"," if os.path.exists(QC_target_temp):\n"," shutil.rmtree(QC_target_temp)\n"," os.makedirs(QC_target_temp)\n","\n","# Create a quality control/Prediction Folder\n"," if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","#Here we move the QC files to the temp\n","\n"," for f in os.listdir(os.path.join(Source_QC_folder)):\n"," shutil.copy(Source_QC_folder+\"/\"+f, QC_source_temp+\"/\"+f)\n","\n"," for p in os.listdir(os.path.join(Target_QC_folder)):\n"," shutil.copy(Target_QC_folder+\"/\"+p, QC_target_temp+\"/\"+p)\n","\n","#Here we convert the XML files into JSON\n","#Save the list of files\n","\n"," list_source_QC_temp = os.listdir(os.path.join(QC_source_temp))\n","\n"," name_no_extension_QC = []\n"," for n in list_source_QC_temp:\n"," name_no_extension_QC.append(os.path.splitext(n)[0])\n","\n"," with open('/content/QC_files.txt', 'w') as f:\n"," for item in name_no_extension_QC:\n"," print(item, end='\\n', file=f)\n","\n","#Convert XML into JSON\n"," file_output_QC = QC_target_temp+\"/output.json\"\n","\n"," os.chdir(\"/content\")\n"," !python voc2coco.py --ann_dir \"$QC_target_temp\" --output \"$file_output_QC\" --ann_ids \"/content/QC_files.txt\" --labels \"/content/labels.txt\" --ext xml\n","\n"," os.chdir(\"/\")\n","\n","\n","#Here we register the QC dataset\n"," 
register_coco_instances(\"my_dataset_QC\", {}, QC_target_temp+\"/output.json\", QC_source_temp)\n"," cell_ran_QC_QC_dataset = 1\n","\n","\n","#Load the model to use\n","cfg = get_cfg()\n","cfg.merge_from_file(full_QC_model_path+\"config.yaml\")\n","cfg.MODEL.WEIGHTS = os.path.join(full_QC_model_path, \"model_final.pth\")\n","cfg.DATASETS.TEST = (\"my_dataset_QC\", )\n","cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5\n","\n","#Metadata\n","test_metadata = MetadataCatalog.get(\"my_dataset_QC\")\n","test_metadata.set(thing_color = color_list)\n","\n","# For the evaluation we need to load the trainer\n","trainer = CocoTrainer(cfg)\n","trainer.resume_or_load(resume=True)\n","\n","# Here we need to load the predictor\n","\n","predictor = DefaultPredictor(cfg)\n","evaluator = COCOEvaluator(\"my_dataset_QC\", cfg, False, output_dir=QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","val_loader = build_detection_test_loader(cfg, \"my_dataset_QC\")\n","inference_on_dataset(trainer.model, val_loader, evaluator)\n","\n","\n","print(\"A prediction is displayed\")\n","\n","dataset_QC_dicts = DatasetCatalog.get(\"my_dataset_QC\")\n","\n","for d in random.sample(dataset_QC_dicts, 1):\n"," print(\"Ground Truth\")\n"," img = cv2.imread(d[\"file_name\"])\n"," visualizer = Visualizer(img[:, :, ::-1], metadata=test_metadata, instance_mode=ColorMode.SEGMENTATION, scale=0.5)\n"," vis = visualizer.draw_dataset_dict(d)\n"," cv2_imshow(vis.get_image()[:, :, ::-1])\n","\n"," print(\"A prediction is displayed\")\n"," im = cv2.imread(d[\"file_name\"])\n"," outputs = predictor(im)\n"," v = Visualizer(im[:, :, ::-1],\n"," metadata=test_metadata,\n"," instance_mode=ColorMode.SEGMENTATION, \n"," scale=0.5\n"," )\n"," out = v.draw_instance_predictions(outputs[\"instances\"].to(\"cpu\"))\n"," cv2_imshow(out.get_image()[:, :, ::-1])\n","\n","cell_ran_QC_QC_dataset = 1"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DWAhOBc7gpzN"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"KAILvLGFS2-1"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predictions are saved in your **Result_folder** as CSV files listing the detected bounding boxes, their classes and confidence scores, together with an overlay image of the last prediction.\n","\n","**`Data_folder`:** This folder should contain the images (.png) that you want to analyse with the trained network.\n","\n","**`Result_folder`:** This folder will contain the prediction results."]},{"cell_type":"code","metadata":{"id":"lp-cx8TDIGI-","cellView":"form"},"source":["\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder as well as the location of your training dataset:\n","\n","#@markdown ####Path to trained model to be assessed: \n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we will load the label file\n","\n","list_of_labels_predictions =[]\n","with open(full_Prediction_model_path+'labels.txt', newline='') as csvfile:\n"," reader = csv.reader(csvfile)\n"," for row in csv.reader(csvfile):\n"," list_of_labels_predictions.append(row[0])\n","\n","#Here we create a list of colors\n","color_list = []\n","for i in range(len(list_of_labels_predictions)):\n"," color = list(np.random.choice(range(256), size=3))\n"," color_list.append(color)\n","\n","#Activate the pretrained model. 
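# Added note: "activating" the model below means rebuilding the Detectron2 config from the
# config.yaml stored in the model folder, pointing MODEL.WEIGHTS at model_final.pth, setting
# a detection score threshold, and wrapping everything in a DefaultPredictor.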
\n","# Create config\n","cfg = get_cfg()\n","cfg.merge_from_file(full_Prediction_model_path+\"config.yaml\")\n","cfg.MODEL.WEIGHTS = os.path.join(full_Prediction_model_path, \"model_final.pth\")\n","cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model\n","\n","# Create predictor\n","predictor = DefaultPredictor(cfg)\n","\n","#Load the metadata from the prediction file\n","prediction_metadata = Metadata()\n","prediction_metadata.set(thing_classes = list_of_labels_predictions)\n","prediction_metadata.set(thing_color = color_list)\n","\n","start = datetime.now()\n","\n","validation_folder = Path(Data_folder)\n","\n","for i, file in enumerate(validation_folder.glob(\"*.png\")):\n"," # this loop opens the .png files from the val-folder, creates a dict with the file\n"," # information, plots visualizations and saves the result as .pkl files.\n"," file = str(file)\n"," file_name = file.split(\"/\")[-1]\n"," im = cv2.imread(file)\n","\n"," #Prediction are done here\n"," outputs = predictor(im)\n","\n"," #here we extract the results into numpy arrays\n","\n"," Classes_predictions = outputs[\"instances\"].pred_classes.cpu().data.numpy()\n","\n"," boxes_predictions = outputs[\"instances\"].pred_boxes.tensor.cpu().numpy()\n"," Score_predictions = outputs[\"instances\"].scores.cpu().data.numpy()\n"," \n"," #here we save the results into a csv file\n"," prediction_csv = Result_folder+\"/\"+file_name+\"_predictions.csv\"\n","\n"," with open(prediction_csv, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['x1','y1','x2','y2','box width','box height', 'class', 'score' ]) \n","\n"," for i in range(len(boxes_predictions)):\n","\n"," x1 = boxes_predictions[i][0]\n"," y1 = boxes_predictions[i][1]\n"," x2 = boxes_predictions[i][2]\n"," y2 = boxes_predictions[i][3]\n"," box_width = x2 - x1\n"," box_height = y2 -y1\n","\n"," writer.writerow([str(x1), str(y1), str(x2), str(y2), str(box_width), str(box_height), str(list_of_labels_predictions[Classes_predictions[i]]), Score_predictions[i]])\n","\n","\n","# The last example is displayed \n","v = Visualizer(im, metadata=prediction_metadata, instance_mode=ColorMode.SEGMENTATION, scale=1)\n","v = v.draw_instance_predictions(outputs[\"instances\"].to(\"cpu\")) \n","plt.figure(figsize=(20,20))\n","plt.imshow(v.get_image()[:, :, ::-1])\n","plt.axis('off');\n","plt.savefig(Result_folder+\"/\"+file_name)\n"," \n","print(\"Time needed for inferencing:\", datetime.now() - start)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wgO7Ok1PBFQj"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"ir5oDtGF-34t"},"source":["# **7. 
Version log**\n","---\n","**v1.13**: \n","\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","\n","* This version also now includes built-in version check and the version log that you're reading now."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS"},"source":["#**Thank you for using Detectron2 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/MaskRCNN_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/MaskRCNN_ZeroCostDL4Mic.ipynb new file mode 100644 index 00000000..eac88d2d --- /dev/null +++ b/Colab_notebooks/Beta notebooks/MaskRCNN_ZeroCostDL4Mic.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"MaskRCNN_ZeroCostDL4Mic.ipynb","provenance":[],"collapsed_sections":["YrTo6T74i7s0","RZL8pqcEi0KY","3yywetML0lUX","F3zreN5K5S2S"]},"kernelspec":{"display_name":"Python 3","name":"python3"}},"cells":[{"cell_type":"markdown","metadata":{"id":"YrTo6T74i7s0"},"source":["# **MaskRCNN**\n","\n","---\n","\n"," This notebook is an implementation of MaskRCNN. This neural network performs instance segmentation. This means it can be used to detect objects in images, segment these objects and classify them. This notebook is based on the work of [He et al.](https://arxiv.org/abs/1703.06870)\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (ZeroCostDL4Mic) (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki)\n","\n","This notebook is based on the following paper: \n","\n","**Mask R-CNN**, arxiv, 2018 by Kaiming He, Georgia Gkioxari, Piotr Dollár, Ross Girshick [here](https://arxiv.org/abs/1703.06870)\n","\n","And source code found in: */~https://github.com/matterport/Mask_RCNN*\n","\n","Provide information on dataset availability and link for download if applicable.\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"RZL8pqcEi0KY"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use ZeroCostDL4Mic notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. 
You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"3yywetML0lUX"},"source":["#**0. Before getting started**\n","---\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that while the file format is flexible (.tif, .png, .jpeg should all work) but these currently **must be of RGB** type.\n","\n","Here's the data structure that you should use:\n","* Experiment A\n"," - **Training dataset**\n"," - Training\n"," - img_1.png, img_1.png.csv, img_2.png, img_2.png.csv, ...\n"," - Validation\n"," - img_a.png, img_a.png.csv, img_b.png, img_b.png.csv,...\n"," - **Quality control dataset**\n"," - Validation\n"," - img_a.png, img_a.png.csv, img_b.png, img_b.png.csv\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","\n"," **Note: This notebook is still in the beta stage.\n","Currently, the notebook works only if the annotation files are in csv format with the following columns:**\n","\n","***| filename | width | height | object_index | class_name | x | y |***\n","\n","where each row in the csv will provide the coordinates **(x,y)** of an edge point in the segmentation mask of an individual object with a dedicated **object_index** (e.g. 1, 2, 3....) and its **class_name** (e.g. 'nucleus' or 'horse' etc.) on the image of dimensions **width** x **height** (pixels). 
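As an illustration of this layout (the file name, class and coordinates below are invented, not part of any provided dataset), an annotation file for one image containing two objects could be written like this, with one row per polygon vertex:

```python
# Illustrative sketch only: writes a toy annotation CSV in the column order documented above
# (filename, width, height, object_index, class_name, x, y). All values are made up.
import csv

rows = [
    # img_1.png is 256 x 256 pixels and contains two 'nucleus' objects;
    # object 1 is a triangle, object 2 a square (vertex coordinates are invented).
    ("img_1.png", 256, 256, 1, "nucleus", 10, 10),
    ("img_1.png", 256, 256, 1, "nucleus", 30, 12),
    ("img_1.png", 256, 256, 1, "nucleus", 20, 35),
    ("img_1.png", 256, 256, 2, "nucleus", 100, 100),
    ("img_1.png", 256, 256, 2, "nucleus", 140, 100),
    ("img_1.png", 256, 256, 2, "nucleus", 140, 140),
    ("img_1.png", 256, 256, 2, "nucleus", 100, 140),
]

with open("img_1.png.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["filename", "width", "height", "object_index", "class_name", "x", "y"])
    writer.writerows(rows)
```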
If you already have a dataset with segmentation masks, we can provide a Fiji macro that can convert the dataset into the correct format.\n","*We are actively working on integrating more flexibility into the annotations this notebook can be used with.*\n","\n","---\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"ffNw8dIQjftT"},"source":["# **1. Install MaskRCNN and dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"iYBjQqd95MpG"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"cellView":"form","id":"UTratIh3-Zl_"},"source":["#@markdown ##Install MaskRCNN and dependencies\n","!pip install fpdf\n","!pip install imgaug\n","!pip install h5py==2.10\n","!git clone /~https://github.com/matterport/Mask_RCNN\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"c3JUL5cQ5cY-"},"source":["## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
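After the runtime has restarted, you can optionally confirm that the h5py pin and the Mask_RCNN clone from section 1.1 are in place before loading the dependencies in section 1.3. A minimal check, offered here only as a convenience sketch, could look like this:

```python
# Optional sanity check (sketch): confirm that the installs from section 1.1
# survived the forced runtime restart.
import os
import h5py

print("h5py version:", h5py.__version__)            # expected: 2.10.x
print("Mask_RCNN cloned:", os.path.isdir("/content/Mask_RCNN"))
```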
\n"]},{"cell_type":"markdown","metadata":{"id":"eLGtfVWE6lu9"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"laDhajuKOs9t","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'MaskRCNN'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory\n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Load Key Dependencies\n","%tensorflow_version 1.x\n","\n","import os\n","import sys\n","import json\n","import datetime\n","import time\n","import numpy as np\n","import skimage.draw\n","from skimage import io\n","import imgaug\n","import pandas as pd\n","import csv\n","import random\n","import datetime\n","import shutil\n","from matplotlib import pyplot as plt\n","import matplotlib.lines as lines\n","from matplotlib.patches import Polygon\n","import IPython.display\n","from PIL import Image, ImageDraw, ImageFont\n","from fpdf import FPDF, HTMLMixin \n","from pip._internal.operations.freeze import freeze\n","import subprocess as sp\n","\n","# Root directory of the project\n","ROOT_DIR = os.path.abspath(\"/content\")\n","# !git clone /~https://github.com/matterport/Mask_RCNN\n","# Import Mask RCNN\n","sys.path.append(ROOT_DIR) # To find local version of the library\n","os.chdir('/content/Mask_RCNN')\n","\n","#Here we need to replace \"self.keras_model.metrics_tensors.append(loss)\" with \"self.keras_model.add_metric(loss, name)\"\n","# in model.py line 2199, otherwise we get version issues.\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","#This function replaces the old default files with new values\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, 
file_path)\n","\n","replace(\"/content/Mask_RCNN/mrcnn/model.py\",'self.keras_model.metrics_tensors.append(loss)','self.keras_model.add_metric(loss, name)')\n","#replace(\"/content/Mask_RCNN/mrcnn/model.py\", \"save_weights_only=True),\", \"save_weights_only=True),\\n keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor = 0.1, patience = 30, min_lr = 0, verbose = 1)\")\n","#replace(\"/content/Mask_RCNN/mrcnn/model.py\", \"save_weights_only=True),\", \"save_weights_only=True),\\n keras.callbacks.CSVLogger('/content/results.csv'),\")\n","replace(\"/content/Mask_RCNN/mrcnn/model.py\",'workers = 0','workers = 1')\n","replace(\"/content/Mask_RCNN/mrcnn/model.py\",'workers = multiprocessing.cpu_count()','workers = 1')\n","replace(\"/content/Mask_RCNN/mrcnn/model.py\",'use_multiprocessing=True','use_multiprocessing=False')\n","replace(\"/content/Mask_RCNN/mrcnn/utils.py\",\"shift = np.array([0, 0, 1, 1])\",\"shift = np.array([0., 0., 1., 1.])\")\n","replace(\"/content/Mask_RCNN/mrcnn/visualize.py\", \"i += 1\",\"i += 1\\n plt.savefig('/content/TrainingDataExample_MaskRCNN.png',bbox_inches='tight',pad_inches=0)\")\n","#replace(\"/content/Mask_RCNN/mrcnn/model.py\",\" class_ids\",\" if config.NUM_CLASSES == 2:\\n class_ids = tf.ones_like(probs[:, 0], dtype=tf.int32)\\n else:\\n class_ids\")\n","\n","#Using this command will allow display of detections below the 0.5 score threshold, if only 1 class beyond background is in the dataset\n","replace(\"/content/Mask_RCNN/mrcnn/model.py\",\"class_ids = tf.argmax(probs\",\"if config.NUM_CLASSES >= 2:\\n class_ids = tf.ones_like(probs[:, 0], dtype=tf.int32)\\n else:\\n class_ids = tf.argmax(probs\")\n","\n","\n","from mrcnn.config import Config\n","from mrcnn import model as modellib, utils\n","from mrcnn import visualize\n","from mrcnn.model import log\n","from mrcnn import utils\n","\n","def get_ax(rows=1, cols=1, size=8):\n"," \"\"\"Return a Matplotlib Axes array to be used in\n"," all visualizations in the notebook. 
Provide a\n"," central point to control graph sizes.\n"," \n"," Change the default size attribute to control the size\n"," of rendered images\n"," \"\"\"\n"," _, ax = plt.subplots(rows, cols, figsize=(size*cols, size*rows))\n"," return ax\n","\n","############################################################\n","# Dataset\n","############################################################\n","\n","class ClassDataset(utils.Dataset):\n"," def load_coco(annotation_file):\n"," dataset = json.load(open(annotation_file, 'r'))\n"," assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))\n"," self.dataset = dataset\n"," self.createIndex()\n","\n"," def createIndex(self):\n"," # create index\n"," print('creating index...')\n"," anns, cats, imgs = {}, {}, {}\n"," imgToAnns,catToImgs = defaultdict(list),defaultdict(list)\n"," if 'annotations' in self.dataset:\n"," for ann in self.dataset['annotations']:\n"," imgToAnns[ann['image_id']].append(ann)\n"," anns[ann['id']] = ann\n","\n"," if 'images' in self.dataset:\n"," for img in self.dataset['images']:\n"," imgs[img['id']] = img\n","\n"," if 'categories' in self.dataset:\n"," for cat in self.dataset['categories']:\n"," cats[cat['id']] = cat\n","\n"," if 'annotations' in self.dataset and 'categories' in self.dataset:\n"," for ann in self.dataset['annotations']:\n"," catToImgs[ann['category_id']].append(ann['image_id'])\n","\n"," print('index created!')\n","\n"," # create class members\n"," self.anns = anns\n"," self.imgToAnns = imgToAnns\n"," self.catToImgs = catToImgs\n"," self.imgs = imgs\n"," self.cats = cats\n","\n"," def load_class(self, dataset_dir, subset):\n"," \"\"\"Load a subset of the dataset.\n"," dataset_dir: Root directory of the dataset.\n"," subset: Subset to load: train or val\n"," \"\"\"\n","\n"," # Add classes. We have only one class to add.\n"," self.add_class(\"Training_Datasets\", 1, \"nucleus\")\n"," \n"," # Train or validation dataset?\n"," assert subset in [\"Training\", \"Validation\"]\n"," dataset_dir = os.path.join(dataset_dir, subset)\n","\n"," # Load annotations\n"," # VGG Image Annotator (up to version 1.6) saves each image in the form:\n"," # { 'filename': '28503151_5b5b7ec140_b.jpg',\n"," # 'regions': {\n"," # '0': {\n"," # 'region_attributes': {},\n"," # 'shape_attributes': {\n"," # 'all_points_x': [...],\n"," # 'all_points_y': [...],\n"," # 'name': 'polygon'}},\n"," # ... more regions ...\n"," # },\n"," # 'size': 100202\n"," # }\n"," # We mostly care about the x and y coordinates of each region\n"," # Note: In VIA 2.0, regions was changed from a dict to a list.\n"," annotations = json.load(open(os.path.join(dataset_dir, \"birds071220220_json.json\")))\n"," annotations = list(annotations.values()) # don't need the dict keys\n"," \n"," # The VIA tool saves images in the JSON even if they don't have any\n"," # annotations. Skip unannotated images.\n"," annotations = [a for a in annotations if a['regions']]\n"," \n"," # Add images\n"," for a in annotations:\n"," # Get the x, y coordinaets of points of the polygons that make up\n"," # the outline of each object instance. 
These are stores in the\n"," # shape_attributes (see json format above)\n"," # The if condition is needed to support VIA versions 1.x and 2.x.\n"," if type(a['regions']) is dict:\n"," polygons = [r['shape_attributes'] for r in a['regions'].values()]\n"," else:\n"," polygons = [r['shape_attributes'] for r in a['regions']] \n","\n"," #Get the class of the object\n"," obj_class = [c['region_attributes']['species'] for c in a['regions']]\n","\n"," # load_mask() needs the image size to convert polygons to masks.\n"," # Unfortunately, VIA doesn't include it in JSON, so we must read\n"," # the image. This is only managable since the dataset is tiny.\n"," image_path = os.path.join(dataset_dir, a['filename'])\n"," image = skimage.io.imread(image_path)\n"," height, width = image.shape[:2]\n","\n"," self.add_image(\n"," \"Training_Datasets\",\n"," image_id=a['filename'], # use file name as a unique image id\n"," path=image_path,\n"," width=width, height=height,\n"," polygons=polygons,\n"," obj_class=obj_class)\n"," \n"," def load_image_csv(self, dataset_dir, subset):\n"," # Add classes. We have only one class to add.\n"," # self.add_class(\"Training_Datasets\", 1, \"nucleus\")\n"," #self.add_class(\"Training_Datasets\", 2, \"Great tit\")\n"," \n"," # Train or validation dataset?\n"," assert subset in [\"Training\", \"Validation\"]\n"," dataset_dir = os.path.join(dataset_dir, subset)\n"," #Data Format\n"," #csv file:\n"," #filename,width,height,object_index, class_name, x, y\n"," #file_1,256,256,1,nucleus, 1, 1\n"," #file_1,256,256,1,nucleus, 3, 10\n"," #file_1,256,256,1,nucleus, 1, 3\n"," #file_1,256,256,1,nucleus, 3, 7\n"," #file_1,256,256,2,nucleus, 17, 20\n"," #...\n"," class_index = 0\n"," obj_class_old = \"\"\n"," #class_names will hold all the classes we find in the dataset \n"," class_names = {obj_class_old:class_index}\n"," for csv_file_name in os.listdir(dataset_dir):\n"," if csv_file_name.endswith('.csv'):\n"," with open(os.path.join(dataset_dir,csv_file_name)) as csvfile_count:\n"," row_count = sum(1 for _ in csvfile_count)\n"," with open(os.path.join(dataset_dir,csv_file_name)) as csvfile:\n"," annotations = csv.reader(csvfile)\n"," next(annotations)\n"," polygons = []\n"," x_values = []\n"," y_values = []\n"," index_old = 1\n"," for line in annotations:\n"," img_file_name = line[0]\n"," index_new = int(line[4])\n"," obj_class = line[3]\n"," \n"," if not obj_class in class_names:\n"," class_index+=1\n"," class_names[obj_class] = class_index\n"," self.add_class(\"Training_Datasets\", class_index, obj_class)\n"," \n"," if index_new == index_old:\n"," x_values.append(int(line[5]))\n"," y_values.append(int(line[6]))\n"," \n"," if row_count == annotations.line_num:\n"," polygon = {\"class_name\":class_names[obj_class],\"all_points_x\":x_values,\"all_points_y\":y_values}\n"," polygons.append(polygon)\n"," \n"," elif index_new != index_old:\n"," polygon = {\"class_name\":class_names[obj_class_old],\"all_points_x\":x_values,\"all_points_y\":y_values}\n"," polygons.append(polygon)\n"," x_values = []\n"," x_values.append(int(line[5]))\n"," y_values = []\n"," y_values.append(int(line[6]))\n"," \n"," index_old = int(line[4])\n"," obj_class_old = line[3]\n"," image_path = os.path.join(dataset_dir,img_file_name)\n"," \n"," self.add_image(\n"," \"Training_Datasets\",\n"," image_id=img_file_name, # use file name as a unique image id\n"," path=image_path,\n"," width=int(line[1]), height=int(line[2]),\n"," polygons=polygons)\n"," #print(csv_file_name, class_index, polygons)\n"," return 
class_index\n","\n"," def load_mask(self, image_id):\n"," info = self.image_info[image_id]\n"," #print(info)\n"," mask = np.zeros([info[\"height\"], info[\"width\"], len(info[\"polygons\"])],\n"," dtype=np.uint8)\n"," class_ids = []\n"," #class_index = 0\n"," for i, p in enumerate(info[\"polygons\"]):\n"," \n"," class_name = p['class_name']\n"," # class_names = {class_name:class_index}\n"," # if class_name != class_name_old:\n"," # class_index+=1\n"," # class_names[class_name] = class_index\n"," \n"," # Get indexes of pixels inside the polygon and set them to 1\n"," # print(p['y_values'])\n"," rr, cc = skimage.draw.polygon(p['all_points_y'], p['all_points_x'])\n"," mask[rr, cc, i] = 1\n"," \n"," #class_name_old = p['class_name']\n"," class_ids.append(class_name)\n"," \n"," class_ids = np.array(class_ids)\n","\n"," return mask.astype(np.bool), class_ids.astype(np.int32)\n","\n"," # def load_mask(self, image_id):\n"," # \"\"\"Generate instance masks for an image.\n"," # Returns:\n"," # masks: A bool array of shape [height, width, instance count] with\n"," # one mask per instance.\n"," # class_ids: a 1D array of class IDs of the instance masks.\n"," # \"\"\"\n"," # def clean_name(name):\n"," # \"\"\"Returns a shorter version of object names for cleaner display.\"\"\"\n"," # return \",\".join(name.split(\",\")[:1])\n","\n"," # # If not a balloon dataset image, delegate to parent class.\n"," # image_info = self.image_info[image_id]\n"," # if image_info[\"source\"] != \"Training_Datasets\":\n"," # return super(self.__class__, self).load_mask(image_id)\n","\n"," # # Convert polygons to a bitmap mask of shape\n"," # # [height, width, instance_count]\n"," # info = self.image_info[image_id]\n","\n"," # mask = np.zeros([info[\"height\"], info[\"width\"], len(info[\"polygons\"])],\n"," # dtype=np.uint8)\n"," # for i, p in enumerate(info[\"polygons\"]):\n"," # # Get indexes of pixels inside the polygon and set them to 1\n"," # rr, cc = skimage.draw.polygon(p['all_points_y'], p['all_points_x'])\n"," # mask[rr, cc, i] = 1\n","\n"," # classes = info[\"obj_class\"]\n"," # class_list = [clean_name(c[\"name\"]) for c in self.class_info]\n"," # class_ids = np.array([class_list.index(s) for s in classes])\n","\n"," # # Return mask, and array of class IDs of each instance. Since we have\n"," # # one class ID only, we return an array of 1s\n"," # return mask.astype(np.bool), class_ids.astype(np.int32)#np.ones([mask.shape[-1]], dtype=np.int32)\n","\n"," def image_reference(self, image_id):\n"," \"\"\"Return the path of the image.\"\"\"\n"," info = self.image_info[image_id]\n"," if info[\"source\"] == \"Training_Datasets\":\n"," return info[\"path\"]\n"," else:\n"," super(self.__class__, self).image_reference(image_id)\n","\n","\n","def train(model, augmentation=True):\n"," \"\"\"Train the model.\"\"\"\n"," # Training dataset.\n"," dataset_train = ClassDataset()\n"," dataset_train.load_class('/content/gdrive/MyDrive/MaskRCNN/Training_Datasets', \"Training\")\n"," dataset_train.prepare()\n","\n"," # Validation dataset\n"," dataset_val = ClassDataset()\n"," dataset_val.load_class('/content/gdrive/MyDrive/MaskRCNN/Training_Datasets', \"Validation\")\n"," dataset_val.prepare()\n","\n"," if augmentation == True:\n"," augment = imgaug.augmenters.Sometimes(0.5, imgaug.augmenters.OneOf([imgaug.augmenters.Fliplr(0.5),\n"," imgaug.augmenters.Flipud(0.5),\n"," imgaug.augmenters.Affine(rotate=45)]))\n"," else:\n"," augment = None\n"," # *** This training schedule is an example. 
Update to your needs ***\n"," # Since we're using a very small dataset, and starting from\n"," # COCO trained weights, we don't need to train too long. Also,\n"," # no need to train all layers, just the heads should do it.\n"," print(\"Training network heads\")\n"," model.train(dataset_train, dataset_val,\n"," learning_rate=config.LEARNING_RATE,\n"," epochs=80,\n"," augmentation = augment,\n"," layers='heads')\n","\n","\n","def train_csv(model, training_folder, augmentation=True, epochs = 20, layers = 'heads'):\n"," \"\"\"Train the model.\"\"\"\n"," # Training dataset.\n"," dataset_train = ClassDataset()\n"," dataset_train.load_image_csv(training_folder, \"Training\")\n"," dataset_train.prepare()\n","\n"," # Validation dataset\n"," dataset_val = ClassDataset()\n"," dataset_val.load_image_csv(training_folder, \"Validation\")\n"," dataset_val.prepare()\n","\n"," if augmentation == True:\n"," augment = imgaug.augmenters.SomeOf((1,2),[imgaug.augmenters.OneOf([imgaug.augmenters.Affine(rotate=90),\n"," imgaug.augmenters.Affine(rotate=180),\n"," imgaug.augmenters.Affine(rotate=270)]),\n"," imgaug.augmenters.Fliplr(0.5),\n"," imgaug.augmenters.Flipud(0.5),\n"," imgaug.augmenters.Multiply((0.8, 1.5)),\n"," imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))])\n"," else:\n"," augment = None\n"," # *** This training schedule is an example. Update to your needs ***\n"," # Since we're using a very small dataset, and starting from\n"," # COCO trained weights, we don't need to train too long. Also,\n"," # no need to train all layers, just the heads should do it.\n"," print(\"Training network heads\")\n"," model.train(dataset_train, dataset_val,\n"," learning_rate=config.LEARNING_RATE,\n"," epochs=epochs,\n"," augmentation = augment,\n"," layers=layers)\n","\n","def color_splash(image, mask):\n"," \"\"\"Apply color splash effect.\n"," image: RGB image [height, width, 3]\n"," mask: instance segmentation mask [height, width, instance count]\n"," Returns result image.\n"," \"\"\"\n"," # Make a grayscale copy of the image. 
The grayscale copy still\n"," # has 3 RGB channels, though.\n"," gray = skimage.color.gray2rgb(skimage.color.rgb2gray(image)) * 255\n"," # Copy color pixels from the original color image where mask is set\n"," if mask.shape[-1] > 0:\n"," # We're treating all instances as one, so collapse the mask into one layer\n"," mask = (np.sum(mask, -1, keepdims=True) >= 1)\n"," splash = np.where(mask, image, gray).astype(np.uint8)\n"," else:\n"," splash = gray.astype(np.uint8)\n"," return splash\n","\n","\n","def detect_and_color_splash(model, image_path=None, video_path=None):\n"," assert image_path or video_path\n","\n"," # Image or video?\n"," if image_path:\n"," # Run model detection and generate the color splash effect\n"," print(\"Running on {}\".format(args.image))\n"," # Read image\n"," image = skimage.io.imread(args.image)\n"," # Detect objects\n"," r = model.detect([image], verbose=1)[0]\n"," # Color splash\n"," splash = color_splash(image, r['masks'])\n"," # Save output\n"," file_name = \"splash_{:%Y%m%dT%H%M%S}.png\".format(datetime.datetime.now())\n"," skimage.io.imsave(file_name, splash)\n"," elif video_path:\n"," import cv2\n"," # Video capture\n"," vcapture = cv2.VideoCapture(video_path)\n"," width = int(vcapture.get(cv2.CAP_PROP_FRAME_WIDTH))\n"," height = int(vcapture.get(cv2.CAP_PROP_FRAME_HEIGHT))\n"," fps = vcapture.get(cv2.CAP_PROP_FPS)\n","\n"," # Define codec and create video writer\n"," file_name = \"splash_{:%Y%m%dT%H%M%S}.avi\".format(datetime.datetime.now())\n"," vwriter = cv2.VideoWriter(file_name,\n"," cv2.VideoWriter_fourcc(*'MJPG'),\n"," fps, (width, height))\n","\n"," count = 0\n"," success = True\n"," while success:\n"," print(\"frame: \", count)\n"," # Read next image\n"," success, image = vcapture.read()\n"," if success:\n"," # OpenCV returns images as BGR, convert to RGB\n"," image = image[..., ::-1]\n"," # Detect objects\n"," r = model.detect([image], verbose=0)[0]\n"," # Color splash\n"," splash = color_splash(image, r['masks'])\n"," # RGB -> BGR to save image to video\n"," splash = splash[..., ::-1]\n"," # Add image to video writer\n"," vwriter.write(splash)\n"," count += 1\n"," vwriter.release()\n"," print(\"Saved to \", file_name)\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m'\n","\n","class ClassConfig(Config):\n"," \"\"\"Configuration for training on the toy dataset.\n"," Derives from the base Config class and overrides some values.\n"," \"\"\"\n"," # Give the configuration a recognizable name\n"," # We use a GPU with 12GB memory, which can fit two images.\n"," # Adjust down if you use a smaller GPU.\n"," IMAGES_PER_GPU = 1\n"," DETECTION_MIN_CONFIDENCE = 0\n"," NAME = \"nucleus\"\n"," # Backbone network architecture\n"," # Supported values are: resnet50, resnet101\n"," BACKBONE = \"resnet50\"\n"," # Input image resizing\n"," # Random crops of size 64x64\n"," IMAGE_RESIZE_MODE = \"crop\"\n"," IMAGE_MIN_DIM = 256\n"," IMAGE_MAX_DIM = 256\n"," IMAGE_MIN_SCALE = 2.0\n"," # Length of square anchor side in pixels\n"," RPN_ANCHOR_SCALES = (4, 8, 16, 32, 64)\n"," # ROIs kept after non-maximum supression (training and inference)\n"," POST_NMS_ROIS_TRAINING = 200\n"," POST_NMS_ROIS_INFERENCE = 400\n"," # Non-max suppression threshold to filter RPN proposals.\n"," # You can increase this during training to generate more propsals.\n"," RPN_NMS_THRESHOLD = 0.9\n"," # How many anchors per image to use for RPN training\n"," RPN_TRAIN_ANCHORS_PER_IMAGE = 64\n"," # Image mean (RGB)\n"," MEAN_PIXEL = 
np.array([43.53, 39.56, 48.22])\n"," # If enabled, resizes instance masks to a smaller size to reduce\n"," # memory load. Recommended when using high-resolution images.\n"," USE_MINI_MASK = True\n"," MINI_MASK_SHAPE = (56, 56) # (height, width) of the mini-mask\n"," TRAIN_ROIS_PER_IMAGE = 128\n"," # Maximum number of ground truth instances to use in one image\n"," MAX_GT_INSTANCES = 100\n"," # Max number of final detections per image\n"," DETECTION_MAX_INSTANCES = 200\n","\n","# Below we define a function which saves the predictions.\n","# It is from this branch:\n","# /~https://github.com/matterport/Mask_RCNN/commit/bc8f148b820ebd45246ed358a120c99b09798d71\n","\n","def save_image(image, image_name, boxes, masks, class_ids, scores, class_names, filter_classs_names=None,\n"," scores_thresh=0.1, save_dir=None, mode=0):\n"," \"\"\"\n"," image: image array\n"," image_name: image name\n"," boxes: [num_instance, (y1, x1, y2, x2, class_id)] in image coordinates.\n"," masks: [num_instances, height, width]\n"," class_ids: [num_instances]\n"," scores: confidence scores for each box\n"," class_names: list of class names of the dataset\n"," filter_classs_names: (optional) list of class names we want to draw\n"," scores_thresh: (optional) threshold of confidence scores\n"," save_dir: (optional) the path to store image\n"," mode: (optional) select the result which you want\n"," mode = 0 , save image with bbox,class_name,score and mask;\n"," mode = 1 , save image with bbox,class_name and score;\n"," mode = 2 , save image with class_name,score and mask;\n"," mode = 3 , save mask with black background;\n"," \"\"\"\n"," mode_list = [0, 1, 2, 3]\n"," assert mode in mode_list, \"mode's value should in mode_list %s\" % str(mode_list)\n","\n"," if save_dir is None:\n"," save_dir = os.path.join(os.getcwd(), \"output\")\n"," if not os.path.exists(save_dir):\n"," os.makedirs(save_dir)\n","\n"," useful_mask_indices = []\n","\n"," N = boxes.shape[0]\n"," if not N:\n"," print(\"\\n*** No instances in image %s to draw *** \\n\" % (image_name))\n"," return\n"," else:\n"," assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]\n","\n"," for i in range(N):\n"," # filter\n"," class_id = class_ids[i]\n"," score = scores[i] if scores is not None else None\n"," if score is None or score < scores_thresh:\n"," continue\n","\n"," label = class_names[class_id]\n"," if (filter_classs_names is not None) and (label not in filter_classs_names):\n"," continue\n","\n"," if not np.any(boxes[i]):\n"," # Skip this instance. Has no bbox. 
Likely lost in image cropping.\n"," continue\n","\n"," useful_mask_indices.append(i)\n","\n"," if len(useful_mask_indices) == 0:\n"," print(\"\\n*** No instances in image %s to draw *** \\n\" % (image_name))\n"," return\n","\n"," colors = visualize.random_colors(len(useful_mask_indices))\n","\n"," if mode != 3:\n"," masked_image = image.astype(np.uint8).copy()\n"," else:\n"," masked_image = np.zeros(image.shape).astype(np.uint8)\n","\n"," if mode != 1:\n"," for index, value in enumerate(useful_mask_indices):\n"," masked_image = visualize.apply_mask(masked_image, masks[:, :, value], colors[index])\n","\n"," masked_image = Image.fromarray(masked_image)\n","\n"," if mode == 3:\n"," masked_image.save(os.path.join(save_dir, '%s' % (image_name)))\n"," return\n","\n"," draw = ImageDraw.Draw(masked_image)\n"," colors = np.array(colors).astype(int) * 255\n","\n"," for index, value in enumerate(useful_mask_indices):\n"," class_id = class_ids[value]\n"," score = scores[value]\n"," label = class_names[class_id]\n","\n"," y1, x1, y2, x2 = boxes[value]\n"," if mode != 2:\n"," color = tuple(colors[index])\n"," draw.rectangle((x1, y1, x2, y2), outline=color)\n","\n"," # Label\n"," font = ImageFont.load_default()\n"," draw.text((x1, y1), \"%s %f\" % (label, score), (255, 255, 255), font)\n","\n"," masked_image.save(os.path.join(save_dir, '%s' % (image_name)))\n","\n","def pdf_export(config, trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," config_list = \"\"\n"," for a in dir(config):\n"," if not a.startswith(\"__\") and not callable(getattr(config, a)):\n"," config_list += \"{}: {}\\n\".format(a, getattr(config, a))\n"," \n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'MaskRCNN'\n"," day = datetime.datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+'):\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell\n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = sp.run('nvcc --version',stdout=sp.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = sp.run('nvidia-smi',stdout=sp.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n"," try:\n"," shape = 
io.imread(Training_source+'/Training/'+os.listdir(Training_source+'/Training')[0]).shape\n"," except:\n"," shape = io.imread(Training_source+'/Training/'+os.listdir(Training_source+'/Training')[0][:-4]).shape\n"," dataset_size = len(os.listdir(Training_source))/2\n","\n"," text = 'The '+Network+' model was trained using weights initialised on the coco dataset for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' labelled images (image dimensions: '+str(shape)+') with a batch size of '+str(config.BATCH_SIZE)+' and custom loss functions for region proposal and classification, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' labelled images (image dimensions: '+str(shape)+') with a batch size of '+str(config.BATCH_SIZE)+' and custom loss functions for region proposal and classification, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a previous model checkpoint (model: '+os.path.basename(pretrained_model_path)[:-8]+', checkpoint: '+str(int(pretrained_model_path[-7:-3]))+'). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by vertical and horizontal flipping'\n"," # if multiply_dataset_by >= 2:\n"," # aug_text = aug_text+'\\n- flipping'\n"," # if multiply_dataset_by > 2:\n"," # aug_text = aug_text+'\\n- rotation'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(4)\n"," pdf.multi_cell(200, 5, txt=config_list)\n","\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source+'/Training', align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Validation:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source+'/Validation', align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," 
pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example ground-truth annotation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_MaskRCNN.png').shape\n"," pdf.image('/content/TrainingDataExample_MaskRCNN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- MaskRCNN: Kaiming He, Georgia Gkioxari, Piotr Dollár, Ross Girshick. \"Mask R - CNN\" arxiv. 2018.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," if augmentation:\n"," ref_3 = '- imgaug: Jung, Alexander et al., /~https://github.com/aleju/imgaug, (2020)'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(os.path.dirname(model.log_dir)+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'MaskRCNN'\n","\n"," day = datetime.datetime.now()\n"," datetime_str = str(day)[0:16]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+', checkpoint:'+str(Checkpoint)+')\\nDate and Time: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," if os.path.exists(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png'):\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png').shape\n"," pdf.image(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(80, 5, txt = 'P-R curves for test dataset', ln=1, align='L')\n"," pdf.ln(2)\n"," #for i in range(len(AP)):\n"," # os.path.exists(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png'):\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/P-R_curve_'+QC_model_name+'.png').shape\n"," pdf.ln(1)\n"," pdf.image(QC_model_folder+'/Quality Control/P-R_curve_'+QC_model_name+'.png', x=16, y=None, w=round(exp_size[1]/4), h=round(exp_size[0]/4))\n"," # else:\n"," # pdf.cell(100, 5, txt='For the class 
'+config['model']['labels'][i]+' the model did not predict any objects.', ln=1, align='L')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(QC_model_folder+'/Quality Control/QC_results.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," class_name = header[0]\n"," gt = header[1]\n"," tp = header[2]\n"," fn = header[3]\n"," iou = header[4]\n"," mAP = header[5]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(class_name,gt,tp,fn,iou,mAP)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," class_name = row[0]\n"," gt = row[1]\n"," tp = row[2]\n"," fn = row[3]\n"," iou = row[4]\n"," mAP = row[5]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(class_name,str(gt),str(tp),str(fn),str(iou),str(mAP))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}
{0}{1}{2}{3}{4}{5}
\"\"\"\n","\n"," pdf.write_html(html)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(3)\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(QC_model_folder+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.ln(3)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- MaskRCNN: Kaiming He, Georgia Gkioxari, Piotr Dollár, Ross Girshick. \"Mask R - CNN\" arxiv. 2018.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(QC_model_folder+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+QC_model_folder+'/Quality Control/')\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"s7_nokQv7M4-"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"5-hsYVdkjKuI"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"-goWypUVEvnp"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. 
To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"L_pjmwONjTvb"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"QK-DDu1ljVna"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"P-YFjdLR-5hv"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
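Before filling in the parameters of section 3, it can help to verify that your training folder follows the layout described in section 0 (a *Training* and a *Validation* subfolder, each holding images with matching `.csv` annotation files). The short, optional check below is only a sketch; the path is an example and should be replaced with your own:

```python
# Optional sketch: count images and csv annotation files in each subset.
# "/content/gdrive/MyDrive/Experiment_A/Training_dataset" is an example path, not a provided one.
import os

training_source = "/content/gdrive/MyDrive/Experiment_A/Training_dataset"

for subset in ("Training", "Validation"):
    folder = os.path.join(training_source, subset)
    files = os.listdir(folder)
    n_csv = sum(f.endswith(".csv") for f in files)
    n_img = len(files) - n_csv
    print(f"{subset}: {n_img} images, {n_csv} annotation files")
```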
"]},{"cell_type":"markdown","metadata":{"id":"Do_LZbDmpJiZ"},"source":["# **3. Select your paths and parameters**\n","\n","---\n","\n","The code below allows the user to enter the paths to where the training data is and to define the training parameters.\n","\n","If your dataset is large, this step can take a while. \n","\n","**Note:** The BG class reported by MaskRCNN stands for 'background'. By default BG is the default class in MaskRCNN, so even if your dataset contains only one class, MaskRCNN will treat the dataset as a two-class set.\n"]},{"cell_type":"markdown","metadata":{"id":"M5QFEW-HpRdQ"},"source":["## **3.1. Setting the main training parameters**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"vdLRX63upWcB"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** This is the path to your folder containing the subfolders *Training* and *Validation*, each containing images with their respective annotations. **If your files are not organised in this way, the notebook will NOT work. So make sure everything looks right!** To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten. **Note that MaskRCNN will add a timestamp to your model_name in the form: model_name*YearMonthDayTHourMinute***\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`Training Depth`:** Here, you can choose how much you want to train the network. MaskRCNN is already pretrained on a large dataset which means its weights are already initialised. This means it may not be necessary to train the full model to reach satisfactory results on your dataset. To get the most out of the model, we recommend training the headlayers first for ca. 30 epochs, and then retraining the same model with an increasing depth for further 10s of epochs. To do this, use the same model_name in this section, with any other needed parameters and then load the desired weights file in section 3.3. **Default value: Head layers only**\n","\n","**`number_of_epochs`:**Enter the number of epochs the networks will be trained for. Note that if you want to continue training a previously trained model, enter the final number of epochs you want to use, i.e. if your previous model was trained for 50 epochs and you want to train it to 80, enter 80 epochs here, not 30.\n","**Default value: 50**\n","\n","**`detection_confidence`:** The network will assign scores of confidence to any predictions of ROIs it makes on the dataset during training. The detection confidence here indicates what threshold score you want to apply for the network to use accept any predicted ROIs. We recommend starting low here. If you notice your network is giving you too many ROIs, then increase this value gradually. **Default value: 0**\n","\n","**`learning_rate:`** Input the initial value to be used as learning rate. The learning rate will decrease after 7 epochs if the validation loss does not improve. 
**Default value: 0.003**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"kajoWCX8ps4O","cellView":"form"},"source":["#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","\n","# Ground truth images\n","#Training_validation = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","##@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = os.path.join(model_path,model_name)\n","# if os.path.exists(full_model_path):\n","# print(bcolors.WARNING+'Model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","\n","Training_depth = \"3+ resnet layers\" #@param [\"Head_layers_only\", \"3+ resnet layers\", \"4+ resnet layers\", \"5+ resnet layers\", \"all layers\"]\n","##@markdown ###Advanced Parameters\n","\n","#Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","##@markdown ###If not, please input:\n","\n","number_of_epochs = 10#@param {type:\"integer\"}\n","\n","batch_size = 4#@param{type:\"integer\"}\n","\n","image_resize_mode = \"none\"\n","\n","detection_confidence = 0 #@param {type:\"number\"}\n","\n","region_proposal_nms_threshold = 0.9 #@param{type:\"number\"}\n","\n","learning_rate = 0.003 #@param {type:\"number\"}\n","\n","#@markdown ###Loss weights\n","\n","region_proposal_class_loss = 1#@param {type:\"number\"}\n","region_proposal_class_loss = float(region_proposal_class_loss)\n","\n","region_proposal_bbox_loss = 1#@param {type:\"number\"}\n","region_proposal_bbox_loss = float(region_proposal_bbox_loss)\n","\n","mrcnn_class_loss = 1#@param {type:\"number\"}\n","mrcnn_class_loss = float(mrcnn_class_loss)\n","\n","mrcnn_bbox_loss = 1#@param {type:\"number\"}\n","mrcnn_bbox_loss = float(mrcnn_bbox_loss)\n","\n","mrcnn_mask_loss = 1#@param {type:\"number\"}\n","mrcnn_mask_loss = float(mrcnn_mask_loss)\n","\n","# Path to trained weights file\n","COCO_WEIGHTS_PATH = os.path.join(ROOT_DIR, \"mask_rcnn_coco.h5\")\n","\n","# Directory to save logs and model checkpoints, if not provided\n","# through the command line argument --logs\n","DEFAULT_LOGS_DIR = model_path\n","\n","dataset_train = ClassDataset()\n","dataset_train.load_image_csv(Training_source, \"Training\")\n","dataset_train.prepare()\n","\n","print(\"Class Count: {}\".format(dataset_train.num_classes))\n","for i, info in enumerate(dataset_train.class_info):\n"," print(\"{:3}. 
{:50}\".format(i, info['name']))\n","\n","############################################################\n","# Configurations\n","############################################################\n","\n","\n","class ClassConfig(Config):\n"," \"\"\"Configuration for training on the toy dataset.\n"," Derives from the base Config class and overrides some values.\n"," \"\"\"\n"," # Give the configuration a recognizable name\n"," NAME = model_name\n","\n"," # We use a GPU with 12GB memory, which can fit two images.\n"," # Adjust down if you use a smaller GPU.\n"," IMAGES_PER_GPU = batch_size\n","\n"," # Number of classes (including background)\n"," NUM_CLASSES = len(dataset_train.class_names) # Background + nucleus\n","\n"," # Number of training steps per epoch\n"," STEPS_PER_EPOCH = (len(os.listdir(Training_source+\"/Training\"))/2) // IMAGES_PER_GPU\n"," VALIDATION_STEPS = (len(os.listdir(Training_source+\"/Validation\"))/2) // IMAGES_PER_GPU\n","\n"," # Skip detections with < 90% confidence\n"," # DETECTION_MIN_CONFIDENCE = detection_confidence\n","\n"," LEARNING_RATE = learning_rate\n","\n"," DETECTION_MIN_CONFIDENCE = 0\n","\n"," # Backbone network architecture\n"," # Supported values are: resnet50, resnet101\n"," BACKBONE = \"resnet101\"\n","\n"," # Input image resizing\n"," # Random crops of size 64x64\n"," IMAGE_RESIZE_MODE = image_resize_mode #\"crop\"\n"," IMAGE_MIN_DIM = 128\n"," IMAGE_MAX_DIM = 128\n"," IMAGE_MIN_SCALE = 2.0\n","\n"," # Length of square anchor side in pixels\n"," RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)\n","\n"," # ROIs kept after non-maximum supression (training and inference)\n"," POST_NMS_ROIS_TRAINING = 2000\n"," POST_NMS_ROIS_INFERENCE = 4000\n","\n"," # Non-max suppression threshold to filter RPN proposals.\n"," # You can increase this during training to generate more propsals.\n"," RPN_NMS_THRESHOLD = region_proposal_nms_threshold\n","\n"," # How many anchors per image to use for RPN training\n"," RPN_TRAIN_ANCHORS_PER_IMAGE = 128\n","\n"," # Image mean (RGB)\n"," MEAN_PIXEL = np.array([43.53, 39.56, 48.22])\n","\n"," # If enabled, resizes instance masks to a smaller size to reduce\n"," # memory load. Recommended when using high-resolution images.\n"," USE_MINI_MASK = False\n"," MINI_MASK_SHAPE = (56, 56) # (height, width) of the mini-mask\n","\n"," # Number of ROIs per image to feed to classifier/mask heads\n"," # The Mask RCNN paper uses 512 but often the RPN doesn't generate\n"," # enough positive proposals to fill this and keep a positive:negative\n"," # ratio of 1:3. 
You can increase the number of proposals by adjusting\n"," # the RPN NMS threshold.\n"," TRAIN_ROIS_PER_IMAGE = 128\n","\n"," # Maximum number of ground truth instances to use in one image\n"," MAX_GT_INSTANCES = 100\n","\n"," # Max number of final detections per image\n"," DETECTION_MAX_INSTANCES = 200\n","\n"," LOSS_WEIGHTS = {\n"," \"rpn_class_loss\": region_proposal_class_loss,\n"," \"rpn_bbox_loss\": region_proposal_bbox_loss,\n"," \"mrcnn_class_loss\": mrcnn_class_loss,\n"," \"mrcnn_bbox_loss\": mrcnn_bbox_loss,\n"," \"mrcnn_mask_loss\": mrcnn_mask_loss\n"," }\n","\n","if Training_depth == \"Head_layers_only\":\n"," layers = \"heads\"\n","elif Training_depth == \"3+ resnet layers\":\n"," layers = \"3+\"\n","elif Training_depth == \"4+ resnet layers\":\n"," layers = \"4+\"\n","elif Training_depth == \"5+ resnet layers\":\n"," layers = \"5+\"\n","else:\n"," layers = \"all\"\n","\n","config = ClassConfig()\n","# Training dataset\n","# dataset_train = ClassDataset()\n","# num_classes = dataset_train.load_image_csv(Training_source, \"Training\")\n","# dataset_train.prepare()\n","# print(\"Class Count: {}\".format(dataset_train.num_classes))\n","# for i, info in enumerate(dataset_train.class_info):\n","# print(\"{:3}. {:50}\".format(i, info['name']))\n","\n","# Load and display random samples\n","image_ids = np.random.choice(dataset_train.image_ids, 1)\n","for image_id in image_ids:\n"," image = dataset_train.load_image(image_id)\n"," mask, class_ids = dataset_train.load_mask(image_id)\n"," visualize.display_top_masks(image, mask, class_ids, dataset_train.class_names, limit=dataset_train.num_classes-1)\n","\n","# plt.savefig('/content/TrainingDataExample_MaskRCNN.png',bbox_inches='tight',pad_inches=0)\n","\n","# image, image_meta, class_ids, bbox, mask = modellib.load_image_gt(\n","# dataset_train, config, image_id, use_mini_mask=False)\n","\n","# visualize.display_instances(image, bbox, mask, class_ids, dataset_train.class_names,\n","# show_bbox=False)\n","model = modellib.MaskRCNN(mode=\"training\", config=config, model_dir=DEFAULT_LOGS_DIR)\n","config.display()\n","Use_pretrained_model = False\n","Use_Data_augmentation = False"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PzWJwWFGlYZi"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and, if the dataset is large enough, the `Use_Data_augmentation` box can be unticked.\n","\n"," If the box is ticked, a simple augmentation of horizontal and vertical flipping will be applied to the dataset."]},{"cell_type":"code","metadata":{"id":"d0BwRHRElaSD","cellView":"form"},"source":["#@markdown ##**Augmentation Options**\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation == True:\n"," # Number of training steps per epoch\n"," class AugClassConfig(ClassConfig):\n"," STEPS_PER_EPOCH = 10*((len(os.listdir(Training_source+\"/Training\"))/2) // batch_size)\n"," VALIDATION_STEPS = 10*((len(os.listdir(Training_source+\"/Validation\"))/2) // batch_size)\n"," \n","if Use_Data_augmentation:\n"," config = AugClassConfig()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"uJjmzKGHk_p9"},"source":["\n","## **3.3. 
Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a MaskRCNN model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**."]},{"cell_type":"code","metadata":{"id":"3JsrRmNbgNeL","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If yes, please provide the path to the model (this path should end with the file extension .h5):\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","if Use_Data_augmentation == True:\n"," config = AugClassConfig()\n","else:\n"," config = ClassConfig()\n","\n","model = modellib.MaskRCNN(mode=\"training\", config=config, model_dir=DEFAULT_LOGS_DIR)\n","model.load_weights(pretrained_model_path, by_name=True)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rTWfoQEPuPad"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"CRPOHMNSo0Sj"},"source":["## **4.1. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches."]},{"cell_type":"code","metadata":{"id":"Li__jcfsTzs6","cellView":"form"},"source":["#@markdown ##Start training\n","\n","pdf_export(config, augmentation = Use_Data_augmentation, pretrained_model=Use_pretrained_model)\n","\n","if os.path.exists(model.log_dir+\"/Quality Control\"):\n"," shutil.rmtree(model.log_dir+\"/Quality Control\")\n","os.makedirs(model.log_dir+\"/Quality Control\")\n","\n","start = time.time()\n","#Here, we start the model training\n","train_csv(model, Training_source, augmentation=Use_Data_augmentation, epochs = number_of_epochs, layers = layers)\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","new_model_name = os.path.basename(model.log_dir)\n","#Here, we just save some interesting parameters from training as a csv file\n","if not os.path.exists(model_path+'/'+new_model_name+'/Quality Control/class_names.csv'):\n"," with open(model_path+'/'+new_model_name+'/Quality Control/class_names.csv','w') as class_count_csv:\n"," class_writer = csv.writer(class_count_csv)\n"," for class_name in dataset_train.class_names:\n"," class_writer.writerow([class_name])\n","\n","if os.path.exists(model_path+'/'+new_model_name+'/Quality Control/training_evaluation.csv'):\n"," with open(model_path+'/'+new_model_name+'/Quality Control/training_evaluation.csv','a') as csvfile:\n"," writer = csv.writer(csvfile)\n"," #print('hello')\n"," #writer.writerow(['epoch','loss','val_loss','learning rate'])\n"," model_starting_checkpoint = int(pretrained_model_path[-7:-3])\n"," for i in range(len(model.keras_model.history.history['loss'])):\n"," 
writer.writerow([str(model_starting_checkpoint+i),model.keras_model.history.history['loss'][i], str(learning_rate)])\n","else:\n"," with open(model_path+'/'+new_model_name+'/Quality Control/training_evaluation.csv','w') as csvfile:\n"," writer = csv.writer(csvfile)\n"," writer.writerow(['epoch','loss','val_loss','learning rate'])\n"," for i in range(len(model.keras_model.history.history['loss'])):\n"," writer.writerow([str(i+1),model.keras_model.history.history['loss'][i], model.keras_model.history.history['val_loss'][i], str(learning_rate)])\n","\n","pdf_export(config, trained=True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n0-RUNbruHa6"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"b10mT10YtngQ"},"source":["#@markdown ###Do you want to assess the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_folder = model_path+'/'+new_model_name\n","\n","QC_model_name = os.path.basename(QC_model_folder)\n","\n","if os.path.exists(QC_model_folder):\n"," print(\"The \"+QC_model_name+\" model will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xOOXTMHkLqYq"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training, both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. 
In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","In this notebook, the training loss curves are plotted using **tensorboard**. However, all the training results are also logged in a csv file in your model folder."]},{"cell_type":"code","metadata":{"cellView":"form","id":"-BpIBHDiOTqK"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","if os.path.exists(QC_model_folder):\n"," os.chdir(QC_model_folder)\n"," %load_ext tensorboard\n"," %tensorboard --logdir \"$QC_model_folder\"\n","else:\n"," print(\"The chosen model or path does not exist. Check if your model_name was saved with a timestamp.\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PdJFjEXRKApD"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display an overlay of the input images ground-truth (solid lines) and predicted boxes (dashed lines). Additionally, the below cell will show the mAP value of the model on the QC data together with plots of the Precision-Recall curves for all the classes in the dataset. If you want to read in more detail about these scores, we recommend [this brief explanation](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173).\n","\n"," In a nutshell:\n","\n","**Precision:** This is the proportion of the correct classifications (true positives) in all the predictions made by the model.\n","\n","**Recall:** This is the proportion of the detected true positives in all the detectable data.\n","\n"," The files provided in the \"QC_data_folder\" should be under a subfolder called validation which contains the images (e.g. as .jpg) and annotations (.csv files)!"]},{"cell_type":"code","metadata":{"id":"8yhm7a3gAFdK","cellView":"form"},"source":["#@markdown ### Provide the path to your quality control dataset.\n","DEFAULT_LOGS_DIR = \"/content/gdrive/MyDrive\"\n","QC_data_folder = \"\" #@param {type:\"string\"}\n","#Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","\n","#Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown #####During training, the model files are automatically saved inside a folder named after model_name in section 3. 
Provide the path to this folder below.\n","#QC_model_folder = \"/content/gdrive/MyDrive/maskrcnn_nucleus20210202T1206\" #@param {type:\"string\"}\n","\n","#@markdown ###Choose the checkpoint you want to evauluate:\n","Checkpoint = 8#@param {type:\"integer\"}\n","\n","#Load the dataset\n","dataset_val = ClassDataset()\n","dataset_val.load_image_csv(QC_data_folder, \"Validation\")\n","dataset_val.prepare()\n","\n","# Activate the (pre-)trained model\n","\n","detection_min_confidence = 0.35 #@param{type:\"number\"}\n","region_proposal_nms_threshold = 0.99 #@param{type:\"number\"}\n","resize_mode = \"none\" #@param[\"none\",\"square\",\"crop\",\"pad64\"]\n","\n","class InferenceConfig(ClassConfig):\n"," IMAGE_RESIZE_MODE = resize_mode\n"," RPN_NMS_THRESHOLD = region_proposal_nms_threshold\n"," NAME = \"nucleus\"\n"," IMAGES_PER_GPU = 1\n"," # Number of classes (including background)\n"," DETECTION_MIN_CONFIDENCE = detection_min_confidence\n"," NUM_CLASSES = len(dataset_val.class_names) # Background + nucleus\n"," RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)\n"," POST_NMS_ROIS_INFERENCE = 15000\n","inference_config = InferenceConfig()\n","\n","# Recreate the model in inference mode\n","#if Use_the_current_trained_model:\n","model = modellib.MaskRCNN(mode=\"inference\", \n"," config=inference_config,\n"," model_dir=QC_model_folder)\n","# else:\n","# model = modellib.MaskRCNN(mode=\"inference\", \n","# config=inference_config,\n","# model_dir=QC_model_folder)\n","\n","# Get path to saved weights\n","if Checkpoint < 10:\n"," qc_model_path = QC_model_folder+\"/mask_rcnn_\"+QC_model_name[:-13]+\"_000\"+str(Checkpoint)+\".h5\"\n","elif Checkpoint < 100:\n"," qc_model_path = QC_model_folder+\"/mask_rcnn_\"+QC_model_name[:-13]+\"_00\"+str(Checkpoint)+\".h5\"\n","elif Checkpoint < 1000:\n"," qc_model_path = QC_model_folder+\"/mask_rcnn_\"+QC_model_name[:-13]+\"_0\"+str(Checkpoint)+\".h5\"\n","\n","# Load trained weights\n","print(\"Loading weights from \", qc_model_path)\n","model.load_weights(qc_model_path, by_name=True)\n","\n","# dataset_val = ClassDataset()\n","# num_classes = dataset_val.load_image_csv(QC_data_folder, \"Validation\")\n","# dataset_val.prepare()\n","\n","image_id = random.choice(dataset_val.image_ids)\n","original_image, image_meta, gt_class_id, gt_bbox, gt_mask =\\\n"," modellib.load_image_gt(dataset_val, inference_config, \n"," image_id, use_mini_mask=False)\n","\n","results = model.detect([original_image], verbose=1)\n","r = results[0]\n","visualize.display_differences(original_image, gt_bbox, gt_class_id, gt_mask, r['rois'], r['class_ids'], r['scores'], r['masks'], dataset_val.class_names, iou_threshold = 0.8, score_threshold= 0.8)\n","# visualize.display_instances(original_image, gt_bbox, gt_mask, gt_class_id, \n","# dataset_val.class_names, figsize=(8, 8))\n","\n","save_image(original_image, \"QC_example_data.png\", r['rois'], r['masks'],\n"," r['class_ids'],r['scores'],dataset_val.class_names,\n"," scores_thresh=0,mode=0,save_dir=QC_model_folder+'/Quality Control')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"IuXEDjWAK6pO"},"source":["##**5.3. 
Precision-Recall Curve**\n","\n"," The p-r curve can give a quantification of how well the model trades off precision against recall on the quality control data.\nSince the training saves model checkpoints for each epoch, you should choose which one you want to use for quality control in the `Checkpoint` box."]},{"cell_type":"code","metadata":{"id":"lzoGZUoCxpSc","cellView":"form"},"source":["#@markdown ###Show the precision-recall curve of the QC data\n","#@markdown Choose an IoU threshold for the p-r plot (between 0 and 1), ignore that the plot title says AP@50:\n","\n","iou_threshold = 0.3 #@param{type:\"number\"}\n","mAP, precisions, recalls, overlaps = utils.compute_ap(gt_bbox, gt_class_id, gt_mask,\n"," r['rois'], r['class_ids'], r['scores'], r['masks'],\n"," iou_threshold=iou_threshold)\n","visualize.plot_precision_recall(mAP, precisions, recalls)\n","plt.savefig(QC_model_folder+'/Quality Control/P-R_curve_'+QC_model_name+'.png',bbox_inches='tight',pad_inches=0)\n","\n","gt_match, pred_match, overlaps = utils.compute_matches(gt_bbox, gt_class_id, gt_mask, r['rois'], r['class_ids'], r['scores'], r['masks'])\n","\n","#TO DO: Implement for multiclasses\n","if len(dataset_val.class_names) == 2:\n"," with open (QC_model_folder+'/Quality Control/QC_results.csv','w') as csvfile:\n"," writer = csv.writer(csvfile)\n"," writer.writerow(['class','gt instances','True positives','False Negatives', 'IoU threshold', 'mAP'])\n"," for index in dataset_val.class_names:\n"," if index != 'BG':\n"," writer.writerow([index, str(len(gt_match)), str(len(pred_match)), str(len(gt_match)-len(pred_match)), str(iou_threshold), str(mAP)])\n"," qc_pdf_export()\n","else:\n"," print('Your dataset has more than one class. This means certain features may not be enabled. We are working on implementing this section fully for multiple classes.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MGBi1lB2vSOr"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section, the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that, your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"HrQPXU0DvWIT"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"7FttSetXvdTB","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","DEFAULT_LOGS_DIR = \"/content/gdrive/MyDrive\"\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model and path to model folder:\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","if Use_the_current_trained_model:\n"," Prediction_model_folder = model_path+'/'+new_model_name\n","\n","#@markdown ###Choose the checkpoint you want to evaluate:\n","Checkpoint = 8#@param {type:\"integer\"}\n","\n","if os.path.exists(Prediction_model_folder+'/Quality Control/class_names.csv'):\n"," print('Prediction classes detected! The model will predict the following classes:')\n"," class_names = []\n"," with open(Prediction_model_folder+'/Quality Control/class_names.csv', 'r') as class_names_csv:\n"," csvreader = csv.reader(class_names_csv)\n"," for row in csvreader:\n"," print(row[0])\n"," class_names.append(row[0])\n","\n","\n","detection_min_confidence = 0.1 #@param{type:\"number\"}\n","region_proposal_nms_threshold = 0.99 #@param{type:\"number\"}\n","resize_mode = \"none\" #@param[\"none\",\"square\",\"crop\",\"pad64\"]\n","post_nms_rois = 10000 #@param{type:\"integer\"}\n","\n","\n","#Load the dataset\n","dataset_val = ClassDataset()\n","dataset_val.load_image_csv(Data_folder, \"Validation\")\n","dataset_val.prepare()\n","\n"," # Activate the (pre-)trained model\n","class InferenceConfig(ClassConfig):\n"," IMAGE_RESIZE_MODE = resize_mode\n"," IMAGE_MIN_DIM = 128\n"," IMAGE_MAX_DIM = 128\n"," IMAGE_MIN_SCALE = 2.0\n"," RPN_NMS_THRESHOLD = region_proposal_nms_threshold\n"," #DETECTION_NMS_THRESHOLD = 0.0\n"," NAME = \"nucleus\"\n"," IMAGES_PER_GPU = 1\n"," # Number of classes (including background)\n"," DETECTION_MIN_CONFIDENCE = detection_min_confidence\n"," NUM_CLASSES = len(dataset_val.class_names) # Background + nucleus\n"," RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)\n"," POST_NMS_ROIS_INFERENCE = post_nms_rois\n","\n","inference_config = InferenceConfig()\n","\n","# Recreate the model in inference mode\n","model = modellib.MaskRCNN(mode=\"inference\", \n"," config=inference_config,\n"," model_dir=Prediction_model_folder)\n","\n","# Get path to saved weights\n","if Checkpoint < 10:\n"," pred_model_path = Prediction_model_folder+\"/mask_rcnn_\"+os.path.basename(Prediction_model_folder[:-13])+\"_000\"+str(Checkpoint)+\".h5\"\n","elif Checkpoint < 100:\n"," pred_model_path = Prediction_model_folder+\"/mask_rcnn_\"+os.path.basename(Prediction_model_folder[:-13])+\"_00\"+str(Checkpoint)+\".h5\"\n","elif Checkpoint < 1000:\n"," pred_model_path = Prediction_model_folder+\"/mask_rcnn_\"+os.path.basename(Prediction_model_folder[:-13])+\"_0\"+str(Checkpoint)+\".h5\"\n","\n","# Load trained weights\n","print(\"Loading weights from \", 
pred_model_path)\n","model.load_weights(pred_model_path, by_name=True)\n","\n","#@markdown ###Choose how you would like to export the predictions:\n","Export_mode = \"image with class_name,score and mask\" #@param[\"image with bbox, class_name, scores, masks\",\"image with bbox,class_name and score\",\"image with class_name,score and mask\",\"mask with black background\"]\n","if Export_mode == \"image with bbox, class_name, scores, masks\":\n"," export_mode = 0\n","elif Export_mode == \"image with bbox,class_name and score\":\n"," export_mode = 1\n","elif Export_mode == \"image with class_name,score and mask\":\n"," export_mode = 2\n","elif Export_mode == \"mask with black background\":\n"," export_mode = 3\n","\n","\n","file_path = os.path.join(Data_folder, 'Validation')\n","for input in os.listdir(file_path):\n"," if input.endswith('.png'):\n"," image = io.imread(os.path.join(file_path,input))\n"," results = model.detect([image], verbose=0)\n"," r = results[0]\n"," save_image(image, \"predicted_\"+input, r['rois'], r['masks'],\n"," r['class_ids'],r['scores'],class_names,\n"," scores_thresh=0,mode=export_mode,save_dir=Result_folder)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Yu4OGubv59qa"},"source":["## **6.2. Inspect the predicted output**\n","---\n"]},{"cell_type":"code","metadata":{"id":"YnWgQZmlIuv9","cellView":"form"},"source":["#@markdown ##Run this cell to display a randomly chosen input with predicted mask.\n","\n","detection_min_confidence = 0.1 #@param{type:\"number\"}\n","region_proposal_nms_threshold = 0.99 #@param{type:\"number\"}\n","resize_mode = \"none\" #@param[\"none\",\"square\",\"crop\",\"pad64\"]\n","post_nms_rois = 10000 #@param{type:\"integer\"}\n","\n"," # Activate the (pre-)trained model\n","class InferenceConfig(ClassConfig):\n"," IMAGE_RESIZE_MODE = resize_mode\n"," IMAGE_MIN_DIM = 128\n"," IMAGE_MAX_DIM = 128\n"," IMAGE_MIN_SCALE = 2.0\n"," RPN_NMS_THRESHOLD = region_proposal_nms_threshold\n"," #DETECTION_NMS_THRESHOLD = 0.0\n"," NAME = \"nucleus\"\n"," IMAGES_PER_GPU = 1\n"," # Number of classes (including background)\n"," DETECTION_MIN_CONFIDENCE = detection_min_confidence\n"," NUM_CLASSES = len(dataset_val.class_names) # Background + nucleus\n"," RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)\n"," POST_NMS_ROIS_INFERENCE = post_nms_rois\n","\n","inference_config = InferenceConfig()\n","\n","\n","model = modellib.MaskRCNN(mode=\"inference\", \n"," config=inference_config,\n"," model_dir=Prediction_model_folder)\n","\n","model.load_weights(pred_model_path, by_name=True)\n","example_image = random.choice(os.listdir(os.path.join(Data_folder,'Validation')))\n","\n","if example_image.endswith('.csv'):\n"," example_image = example_image[:-4]\n","\n","display_image = io.imread(file_path+'/'+example_image)\n","results = model.detect([display_image], verbose=0)\n","\n","r = results[0]\n","\n","visualize.display_instances(display_image, r['rois'], r['masks'], r['class_ids'], \n"," class_names, r['scores'], ax=get_ax())"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"BrosGM4Z50gX"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"JYfEsBazHhkW"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","\n","* This notebook is new as ZeroCostDL4Mic version 1.13. and is currently a beta version. \n","* Further edits to this notebook in future versions will be updated in this cell."]},{"cell_type":"markdown","metadata":{"id":"F3zreN5K5S2S"},"source":["#**Thank you for using MaskRCNN**!"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/U-Net_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb b/Colab_notebooks/Beta notebooks/U-Net_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb index b024b144..5d16aa6d 100644 --- a/Colab_notebooks/Beta notebooks/U-Net_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb +++ b/Colab_notebooks/Beta notebooks/U-Net_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb @@ -1,2422 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "U-Net_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.7" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "IkSguVy8Xv83" - }, - "source": [ - "# **U-Net (2D)**\n", - "---\n", - "\n", - "U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n", - "\n", - " **This particular notebook enables image segmentation of 2D dataset. If you are interested in 3D dataset, you should use the 3D U-Net notebook instead.**\n", - "\n", - "---\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is largely based on the papers: \n", - "\n", - "**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n", - "\n", - "and \n", - "\n", - "**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n", - "(https://www.nature.com/articles/s41592-018-0261-2)\n", - "And source code found in: /~https://github.com/zhixuhao/unet by *Zhixuhao*\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.** " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use our notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gKDLkLWUd-YX" - }, - "source": [ - "# **0. Before getting started**\n", - "---\n", - "\n", - "Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n", - "\n", - "For U-Net to train, **it needs to have access to a paired training dataset corresponding to images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n", - "\n", - "**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n", - "\n", - "Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n", - "\n", - "Here's a common data structure that can work:\n", - "* Experiment A\n", - " - **Training dataset**\n", - " - Training_source\n", - " - img_1.tif, img_2.tif, ...\n", - " - Training_target\n", - " - img_1.tif, img_2.tif, ...\n", - " - **Quality control dataset**\n", - " - Training_source\n", - " - img_1.tif, img_2.tif\n", - " - Training_target \n", - " - img_1.tif, img_2.tif\n", - " - **Data to be predicted**\n", - " - **Results**\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "\n", - "\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. 
Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zCvebubeSaGY", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "%tensorflow_version 1.x\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi\n", - "\n", - "# from tensorflow.python.client import device_lib \n", - "# device_lib.list_local_devices()\n", - "\n", - "# print the tensorflow version\n", - "print('Tensorflow version is ' + str(tf.__version__))\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sNIVx8_CLolt" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "01Djr8v-5pPk", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AdN8B91xZO0x" - }, - "source": [ - "# **2. Install U-Net dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UGWnGOFsf07b" - }, - "source": [ - "## **2.1. Install key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "uc0haIa-fZiG", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play to install 2D U-Net dependencies\n", - "!pip install pydeepimagej==2.1.2\n", - "!pip install data\n", - "!pip install fpdf\n", - "!pip install h5py==2.10\n", - "\n", - "#Force session restart\n", - "# exit(0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I4O5zctbf4Gb" - }, - "source": [ - "## **2.2. 
Restart your runtime**\n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "F4jMunHMfq_c" - }, - "source": [ - "** Skip this step if you already restarted the runtime.**\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iiX3Ly-7gA5h" - }, - "source": [ - "## **2.3. Load key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "fq21zJVFNASx", - "cellView": "form" - }, - "source": [ - "Notebook_version = ['1.12.1']\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory \n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "#@markdown ##Load key U-Net dependencies\n", - "\n", - "#As this notebokk depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)\n", - "#only the data library needs to be additionally installed.\n", - "%tensorflow_version 1.x\n", - "import tensorflow as tf\n", - "# print(tensorflow.__version__)\n", - "# print(\"Tensorflow enabled.\")\n", - "\n", - "\n", - "# Keras imports\n", - "from keras import models\n", - "from keras.models import Model, load_model\n", - "from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n", - "from keras.optimizers import Adam\n", - "# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints\n", - "from keras.callbacks import ModelCheckpoint\n", - "from keras.callbacks import ReduceLROnPlateau\n", - "from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n", - "from keras import backend as keras\n", - "\n", - "# General import\n", - "from __future__ import print_function\n", - "import numpy as np\n", - "import pandas as pd\n", - "import os\n", - "import glob\n", - "from skimage import img_as_ubyte, io, transform\n", - "import matplotlib as mpl\n", - "from matplotlib import pyplot as plt\n", - "from matplotlib.pyplot import imread\n", - "from pathlib import Path\n", - "import shutil\n", - 
"import random\n", - "import time\n", - "import csv\n", - "import sys\n", - "from math import ceil\n", - "from fpdf import FPDF, HTMLMixin\n", - "from pip._internal.operations.freeze import freeze\n", - "import subprocess\n", - "# Imports for QC\n", - "from PIL import Image\n", - "from scipy import signal\n", - "from scipy import ndimage\n", - "from sklearn.linear_model import LinearRegression\n", - "from skimage.util import img_as_uint\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "\n", - "# For sliders and dropdown menu and progress bar\n", - "from ipywidgets import interact\n", - "import ipywidgets as widgets\n", - "# from tqdm import tqdm\n", - "from tqdm.notebook import tqdm\n", - "\n", - "from sklearn.feature_extraction import image\n", - "from skimage import img_as_ubyte, io, transform\n", - "from skimage.util.shape import view_as_windows\n", - "\n", - "from datetime import datetime\n", - "\n", - "\n", - "# Suppressing some warnings\n", - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "\n", - "\n", - "\n", - "\n", - "def create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction):\n", - " \"\"\"\n", - " Function creates patches from the Training_source and Training_target images. \n", - " The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n", - " Saves all created patches in two new directories in the /content folder.\n", - "\n", - " min_fraction is the minimum fraction of pixels that need to be foreground to be considered as a valid patch\n", - "\n", - " Returns: - Two paths to where the patches are now saved\n", - " \"\"\"\n", - " DEBUG = False\n", - "\n", - " Patch_source = os.path.join('/content','img_patches')\n", - " Patch_target = os.path.join('/content','mask_patches')\n", - " Patch_rejected = os.path.join('/content','rejected')\n", - " \n", - "\n", - " #Here we save the patches, in the /content directory as they will not usually be needed after training\n", - " if os.path.exists(Patch_source):\n", - " shutil.rmtree(Patch_source)\n", - " if os.path.exists(Patch_target):\n", - " shutil.rmtree(Patch_target)\n", - " if os.path.exists(Patch_rejected):\n", - " shutil.rmtree(Patch_rejected)\n", - "\n", - " os.mkdir(Patch_source)\n", - " os.mkdir(Patch_target)\n", - " os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n", - " \n", - " patch_num = 0\n", - "\n", - " for file in tqdm(os.listdir(Training_source)):\n", - "\n", - " img = io.imread(os.path.join(Training_source, file))\n", - " mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n", - "\n", - " if DEBUG:\n", - " print(file)\n", - " print(img.dtype)\n", - "\n", - " # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n", - " patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n", - " patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n", - "\n", - " patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)\n", - " patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)\n", - "\n", - " if DEBUG:\n", - " print(all_patches_img.shape)\n", - " print(all_patches_img.dtype)\n", - "\n", - " for i in range(patches_img.shape[0]):\n", - " img_save_path = 
os.path.join(Patch_source,'patch_'+str(patch_num)+'.tif')\n", - " mask_save_path = os.path.join(Patch_target,'patch_'+str(patch_num)+'.tif')\n", - " patch_num += 1\n", - "\n", - " # if the mask conatins at least 2% of its total number pixels as mask, then go ahead and save the images\n", - " pixel_threshold_array = sorted(patches_mask[i].flatten())\n", - " if pixel_threshold_array[int(round(len(pixel_threshold_array)*(1-min_fraction)))]>0:\n", - " io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(patches_img[i])))\n", - " io.imsave(mask_save_path, convert2Mask(normalizeMinMax(patches_mask[i]),0))\n", - " else:\n", - " io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_image.tif', img_as_ubyte(normalizeMinMax(patches_img[i])))\n", - " io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_mask.tif', convert2Mask(normalizeMinMax(patches_mask[i]),0))\n", - "\n", - " return Patch_source, Patch_target\n", - "\n", - "\n", - "def estimatePatchSize(data_path, max_width = 512, max_height = 512):\n", - "\n", - " files = os.listdir(data_path)\n", - " \n", - " # Get the size of the first image found in the folder and initialise the variables to that\n", - " n = 0 \n", - " while os.path.isdir(os.path.join(data_path, files[n])):\n", - " n += 1\n", - " (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n", - "\n", - " # Screen the size of all dataset to find the minimum image size\n", - " for file in files:\n", - " if not os.path.isdir(os.path.join(data_path, file)):\n", - " (height, width) = Image.open(os.path.join(data_path, file)).size\n", - " if width < width_min:\n", - " width_min = width\n", - " if height < height_min:\n", - " height_min = height\n", - " \n", - " # Find the power of patches that will fit within the smallest dataset\n", - " width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n", - "\n", - " # Clip values at maximum permissible values\n", - " if width_min > max_width:\n", - " width_min = max_width\n", - "\n", - " if height_min > max_height:\n", - " height_min = max_height\n", - " \n", - " return (width_min, height_min)\n", - "\n", - "def fittingPowerOfTwo(number):\n", - " n = 0\n", - " while 2**n <= number:\n", - " n += 1 \n", - " return 2**(n-1)\n", - "\n", - "\n", - "def getClassWeights(Training_target_path):\n", - "\n", - " Mask_dir_list = os.listdir(Training_target_path)\n", - " number_of_dataset = len(Mask_dir_list)\n", - "\n", - " class_count = np.zeros(2, dtype=int)\n", - " for i in tqdm(range(number_of_dataset)):\n", - " mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n", - " mask = normalizeMinMax(mask)\n", - " class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n", - " class_count[1] += mask.sum()\n", - "\n", - " n_samples = class_count.sum()\n", - " n_classes = 2\n", - "\n", - " class_weights = n_samples / (n_classes * class_count)\n", - " return class_weights\n", - "\n", - "def weighted_binary_crossentropy(class_weights):\n", - "\n", - " def _weighted_binary_crossentropy(y_true, y_pred):\n", - " binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n", - " weight_vector = y_true * class_weights[1] + (1. 
- y_true) * class_weights[0]\n", - " weighted_binary_crossentropy = weight_vector * binary_crossentropy\n", - "\n", - " return keras.mean(weighted_binary_crossentropy)\n", - "\n", - " return _weighted_binary_crossentropy\n", - "\n", - "\n", - "def save_augment(datagen,orig_img,dir_augmented_data=\"/content/augment\"):\n", - " \"\"\"\n", - " Saves a subset of the augmented data for visualisation, by default in /content.\n", - "\n", - " This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n", - " \n", - " \"\"\"\n", - " try:\n", - " os.mkdir(dir_augmented_data)\n", - " except:\n", - " ## if the preview folder exists, then remove\n", - " ## the contents (pictures) in the folder\n", - " for item in os.listdir(dir_augmented_data):\n", - " os.remove(dir_augmented_data + \"/\" + item)\n", - "\n", - " ## convert the original image to array\n", - " x = img_to_array(orig_img)\n", - " ## reshape (Sampke, Nrow, Ncol, 3) 3 = R, G or B\n", - " #print(x.shape)\n", - " x = x.reshape((1,) + x.shape)\n", - " #print(x.shape)\n", - " ## -------------------------- ##\n", - " ## randomly generate pictures\n", - " ## -------------------------- ##\n", - " i = 0\n", - " #We will just save 5 images,\n", - " #but this can be changed, but note the visualisation in 3. currently uses 5.\n", - " Nplot = 5\n", - " for batch in datagen.flow(x,batch_size=1,\n", - " save_to_dir=dir_augmented_data,\n", - " save_format='tif',\n", - " seed=42):\n", - " i += 1\n", - " if i > Nplot - 1:\n", - " break\n", - "\n", - "# Generators\n", - "def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):\n", - " '''\n", - " Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n", - " \n", - " datagen: ImageDataGenerator \n", - " subset: can take either 'training' or 'validation'\n", - " '''\n", - " seed = 1\n", - " image_generator = image_datagen.flow_from_directory(\n", - " os.path.dirname(image_folder_path),\n", - " classes = [os.path.basename(image_folder_path)],\n", - " class_mode = None,\n", - " color_mode = \"grayscale\",\n", - " target_size = target_size,\n", - " batch_size = batch_size,\n", - " subset = subset,\n", - " interpolation = \"bicubic\",\n", - " seed = seed)\n", - " \n", - " mask_generator = mask_datagen.flow_from_directory(\n", - " os.path.dirname(mask_folder_path),\n", - " classes = [os.path.basename(mask_folder_path)],\n", - " class_mode = None,\n", - " color_mode = \"grayscale\",\n", - " target_size = target_size,\n", - " batch_size = batch_size,\n", - " subset = subset,\n", - " interpolation = \"nearest\",\n", - " seed = seed)\n", - " \n", - " this_generator = zip(image_generator, mask_generator)\n", - " for (img,mask) in this_generator:\n", - " # img,mask = adjustData(img,mask)\n", - " yield (img,mask)\n", - "\n", - "\n", - "def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):\n", - " image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n", - " mask_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizeMinMax)\n", - "\n", - " train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, target_size)\n", - " validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, 
mask_folder_path, 'validation', batch_size, target_size)\n", - "\n", - " return (train_datagen, validation_datagen)\n", - "\n", - "\n", - "# Normalization functions from Martin Weigert\n", - "def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " \"\"\"Percentile-based image normalization.\"\"\"\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - "def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " if dtype is not None:\n", - " x = x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - "\n", - "\n", - "# Simple normalization to min/max fir the Mask\n", - "def normalizeMinMax(x, dtype=np.float32):\n", - " x = x.astype(dtype,copy=False)\n", - " x = (x - np.amin(x)) / (np.amax(x) - np.amin(x))\n", - " return x\n", - "\n", - "\n", - "# This is code outlines the architecture of U-net. The choice of pooling steps decides the depth of the network. \n", - "def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):\n", - " inputs = Input(input_size)\n", - " conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n", - " conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n", - " # Downsampling steps\n", - " pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n", - " conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n", - " conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n", - " \n", - " if pooling_steps > 1:\n", - " pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n", - " conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n", - " conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n", - "\n", - " if pooling_steps > 2:\n", - " pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n", - " conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n", - " conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)\n", - " drop4 = Dropout(0.5)(conv4)\n", - " \n", - " if pooling_steps > 3:\n", - " pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n", - " conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n", - " conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n", - " drop5 = Dropout(0.5)(conv5)\n", - "\n", - " #Upsampling steps\n", - " up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = 
(2,2))(drop5))\n", - " merge6 = concatenate([drop4,up6], axis = 3)\n", - " conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n", - " conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n", - " \n", - " if pooling_steps > 2:\n", - " up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n", - " if pooling_steps > 3:\n", - " up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n", - " merge7 = concatenate([conv3,up7], axis = 3)\n", - " conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n", - " conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)\n", - "\n", - " if pooling_steps > 1:\n", - " up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n", - " if pooling_steps > 2:\n", - " up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n", - " merge8 = concatenate([conv2,up8], axis = 3)\n", - " conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n", - " conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)\n", - " \n", - " if pooling_steps == 1:\n", - " up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n", - " else:\n", - " up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n", - " \n", - " merge9 = concatenate([conv1,up9], axis = 3)\n", - " conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'\n", - " conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n", - " conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n", - " conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)\n", - "\n", - " model = Model(inputs = inputs, outputs = conv10)\n", - "\n", - " # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])\n", - " model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))\n", - "\n", - "\n", - " if verbose:\n", - " model.summary()\n", - "\n", - " if(pretrained_weights):\n", - " \tmodel.load_weights(pretrained_weights);\n", - "\n", - " return model\n", - "\n", - "\n", - "\n", - "def predict_as_tiles(Image_path, model):\n", - "\n", - " # Read the data in and normalize\n", - " Image_raw = io.imread(Image_path, as_gray = True)\n", - " Image_raw = normalizePercentile(Image_raw)\n", - "\n", - " # Get the patch size from the input layer of the model\n", - " patch_size = model.layers[0].output_shape[1:3]\n", - "\n", - " # Pad the image with zeros if any of its dimensions is smaller than the patch size\n", - " if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n", - " Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n", - " Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n", - " else:\n", - " Image = Image_raw\n", - "\n", - " # Calculate the number of 
patches in each dimension\n", - " n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n", - " n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n", - "\n", - " prediction = np.zeros(Image.shape)\n", - "\n", - " for x in range(n_patch_in_width):\n", - " for y in range(n_patch_in_height):\n", - " xi = patch_size[0]*x\n", - " yi = patch_size[1]*y\n", - "\n", - " # If the patch exceeds the edge of the image shift it back \n", - " if xi+patch_size[0] >= Image.shape[0]:\n", - " xi = Image.shape[0]-patch_size[0]\n", - "\n", - " if yi+patch_size[1] >= Image.shape[1]:\n", - " yi = Image.shape[1]-patch_size[1]\n", - " \n", - " # Extract and reshape the patch\n", - " patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n", - " patch = np.reshape(patch,patch.shape+(1,))\n", - " patch = np.reshape(patch,(1,)+patch.shape)\n", - "\n", - " # Get the prediction from the patch and paste it in the prediction in the right place\n", - " predicted_patch = model.predict(patch, batch_size = 1)\n", - " prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = np.squeeze(predicted_patch)\n", - "\n", - "\n", - " return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n", - " \n", - "\n", - "\n", - "\n", - "def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):\n", - " for (filename, image) in zip(source_dir_list, nparray):\n", - " io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), img_as_ubyte(image)) # saving as unsigned 8-bit image\n", - " \n", - " # For masks, threshold the images and return 8 bit image\n", - " if threshold is not None:\n", - " mask = convert2Mask(image, threshold)\n", - " io.imsave(os.path.join(save_path, prefix+'mask_'+os.path.splitext(filename)[0]+'.tif'), mask)\n", - "\n", - "\n", - "def convert2Mask(image, threshold):\n", - " mask = img_as_ubyte(image, force_copy=True)\n", - " mask[mask > threshold] = 255\n", - " mask[mask <= threshold] = 0\n", - " return mask\n", - "\n", - "\n", - "def getIoUvsThreshold(prediction_filepath, groud_truth_filepath):\n", - " prediction = io.imread(prediction_filepath)\n", - " ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)\n", - "\n", - " threshold_list = []\n", - " IoU_scores_list = []\n", - "\n", - " for threshold in range(0,256): \n", - " # Convert to 8-bit for calculating the IoU\n", - " mask = img_as_ubyte(prediction, force_copy=True)\n", - " mask[mask > threshold] = 255\n", - " mask[mask <= threshold] = 0\n", - "\n", - " # Intersection over Union metric\n", - " intersection = np.logical_and(ground_truth_image, np.squeeze(mask))\n", - " union = np.logical_or(ground_truth_image, np.squeeze(mask))\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - "\n", - " threshold_list.append(threshold)\n", - " IoU_scores_list.append(iou_score)\n", - "\n", - " return (threshold_list, IoU_scores_list)\n", - "\n", - "\n", - "\n", - "# -------------- Other definitions -----------\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "prediction_prefix = 'Predicted_'\n", - "\n", - "\n", - "print('-------------------')\n", - "print('U-Net and dependencies installed.')\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - 
"print('Notebook version: '+Notebook_version[0])\n", - "\n", - "strlist = Notebook_version[0].split('.')\n", - "Notebook_version_main = strlist[0]+'.'+strlist[1]\n", - "\n", - "if Notebook_version_main == Latest_notebook_version.columns:\n", - " print(\"This notebook is up-to-date.\")\n", - "else:\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'U-Net 2D'\n", - "\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and method:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - " loss = str(model.loss)[str(model.loss).find('function')+len('function'):str(model.loss).find('.<')]\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(180, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=1)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented by'\n", - " if rotation_range != 0:\n", - " aug_text = aug_text+'\\n- rotation'\n", - " if horizontal_flip == True or vertical_flip == True:\n", - " aug_text = aug_text+'\\n- flipping'\n", - " if zoom_range != 0:\n", - " aug_text = aug_text+'\\n- random zoom magnification'\n", - " if horizontal_shift != 0 or vertical_shift != 0:\n", - " aug_text = aug_text+'\\n- shifting'\n", - " if shear_range != 0:\n", - " aug_text = aug_text+'\\n- image shearing'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<table width=40% style=\"margin-left:0px;\">\n",
- " <tr>\n",
- " <th align=\"left\">Parameter</th>\n",
- " <th align=\"left\">Value</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>number_of_epochs</td>\n",
- " <td>{0}</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>patch_size</td>\n",
- " <td>{1}</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>batch_size</td>\n",
- " <td>{2}</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>number_of_steps</td>\n",
- " <td>{3}</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>percentage_validation</td>\n",
- " <td>{4}</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>initial_learning_rate</td>\n",
- " <td>{5}</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>pooling_steps</td>\n",
- " <td>{6}</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>min_fraction</td>\n",
- " <td>{7}</td>\n",
- " </tr>\n",
- " </table>
\n", - " \"\"\".format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size, number_of_steps, percentage_validation, initial_learning_rate, pooling_steps, min_fraction)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_Unet2D.png').shape\n", - " pdf.image('/content/TrainingDataExample_Unet2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " # if Use_Data_augmentation:\n", - " # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n", - "\n", - " print('------------------------------')\n", - " print('PDF report exported in '+model_path+'/'+model_name+'/')\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'Unet 2D'\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n", - " if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n", - " pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Threshold Optimisation', ln=1, align='L')\n", - " #pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png').shape\n", - " pdf.image(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png', x = 11, y = None, w = round(exp_size[1]/6), h = round(exp_size[0]/7))\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n", - " pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = 
header[0]\n",
- " IoU = header[1]\n",
- " IoU_OptThresh = header[2]\n",
- " header = \"\"\"\n",
- " <tr>\n",
- " <th align=\"left\">{0}</th>\n",
- " <th align=\"left\">{1}</th>\n",
- " <th align=\"left\">{2}</th>\n",
- " </tr>\n",
- " \"\"\".format(image,IoU,IoU_OptThresh)\n",
- " html = html+header\n",
- " i=0\n",
- " for row in metrics:\n",
- " i+=1\n",
- " image = row[0]\n",
- " IoU = row[1]\n",
- " IoU_OptThresh = row[2]\n",
- " cells = \"\"\"\n",
- " <tr>\n",
- " <td align=\"left\">{0}</td>\n",
- " <td align=\"left\">{1}</td>\n",
- " <td align=\"left\">{2}</td>\n",
- " </tr>\n",
- " \"\"\".format(image,str(round(float(IoU),3)),str(round(float(IoU_OptThresh),3)))\n",
- " html = html+cells\n",
- " html = html+\"\"\"\n",
- " </table>\n",
- " 
\"\"\"\n", - " \n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - " print('------------------------------')\n", - " print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HLYcZR9gMv42" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FQ_QxtSWQ7CL" - }, - "source": [ - "## **3.1. Setting main training parameters**\n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AuESFimvMv43" - }, - "source": [ - " **Paths for training data and models**\n", - "\n", - "**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (segmentation masks). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n", - "\n", - "**`model_path`**: Enter the path of the folder where you want to save your model.\n", - "\n", - "**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n", - "\n", - " **Select training parameters**\n", - "\n", - "**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n", - "\n", - "**Advanced parameters - experienced users only**\n", - "\n", - "**`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. Smaller batchsizes may improve training performance slightly but may increase training time. 
If the notebook crashes while loading the dataset this can be due to a batch size that is too large. Decrease the number in this case. **Default: 4**\n",
- "\n",
- "**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. The default value is calculated to ensure this. This behaviour can also be obtained by setting it to 0. Other values can be used for testing.\n",
- "\n",
- " **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also add two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**\n",
- "\n",
- "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n",
- "\n",
- "**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n",
- "\n",
- "**`patch_width` and `patch_height`:** The notebook crops the data in patches of fixed size prior to training. The dimensions of the patches can be defined here. When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n",
- "\n",
- "**`min_fraction`:** Minimum fraction of pixels being foreground for a selected patch to be considered valid. It should be between 0 and 1. **Default value: 0.02** (2%)\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "ewpNJ_I0Mv47",
- "cellView": "form"
- },
- "source": [
- "# ------------- Initial user input ------------\n",
- "#@markdown ###Path to training images:\n",
- "Training_source = '' #@param {type:\"string\"}\n",
- "Training_target = '' #@param {type:\"string\"}\n",
- "\n",
- "model_name = '' #@param {type:\"string\"}\n",
- "model_path = '' #@param {type:\"string\"}\n",
- "\n",
- "#@markdown ###Training parameters:\n",
- "#@markdown Number of epochs\n",
- "number_of_epochs = 200#@param {type:\"number\"}\n",
- "\n",
- "#@markdown ###Advanced parameters:\n",
- "Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n",
- "\n",
- "#@markdown ###If not, please input:\n",
- "batch_size = 4#@param {type:\"integer\"}\n",
- "number_of_steps = 0#@param {type:\"number\"}\n",
- "pooling_steps = 2 #@param [1,2,3,4]{type:\"raw\"}\n",
- "percentage_validation = 10#@param{type:\"number\"}\n",
- "initial_learning_rate = 0.0003 #@param {type:\"number\"}\n",
- "\n",
- "patch_width = 512#@param{type:\"number\"}\n",
- "patch_height = 512#@param{type:\"number\"}\n",
- "min_fraction = 0.02#@param{type:\"number\"}\n",
- "\n",
- "\n",
- "# ------------- Initialising folder, variables and failsafes ------------\n",
- "# Create the folders where to save the model and the QC\n",
- "full_model_path = os.path.join(model_path, model_name)\n",
- "if os.path.exists(full_model_path):\n",
- " print(R+'!! 
WARNING: Folder already exists and will be overwritten !!'+W)\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 4\n", - " pooling_steps = 2\n", - " percentage_validation = 10\n", - " initial_learning_rate = 0.0003\n", - " patch_width, patch_height = estimatePatchSize(Training_source)\n", - " min_fraction = 0.02\n", - "\n", - "\n", - "#The create_patches function will create the two folders below\n", - "# Patch_source = '/content/img_patches'\n", - "# Patch_target = '/content/mask_patches'\n", - "print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n", - "\n", - "#Create patches\n", - "print('Creating patches...')\n", - "Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction)\n", - "\n", - "number_of_training_dataset = len(os.listdir(Patch_source))\n", - "print('Total number of valid patches: '+str(number_of_training_dataset))\n", - "\n", - "if Use_Default_Advanced_Parameters or number_of_steps == 0:\n", - " number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n", - "print('Number of steps: '+str(number_of_steps))\n", - "\n", - "# Calculate the number of steps to use for validation\n", - "validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n", - "\n", - "\n", - "# Here we disable pre-trained model by default (in case the next cell is not ran)\n", - "Use_pretrained_model = False\n", - "# Here we disable data augmentation by default (in case the cell is not ran)\n", - "Use_Data_augmentation = False\n", - "# Build the default dict for the ImageDataGenerator\n", - "data_gen_args = dict(width_shift_range = 0.,\n", - " height_shift_range = 0.,\n", - " rotation_range = 0., #90\n", - " zoom_range = 0.,\n", - " shear_range = 0.,\n", - " horizontal_flip = False,\n", - " vertical_flip = False,\n", - " validation_split = percentage_validation/100,\n", - " fill_mode = 'reflect')\n", - "\n", - "# ------------- Display ------------\n", - "\n", - "#if not os.path.exists('/content/img_patches/'):\n", - "random_choice = random.choice(os.listdir(Patch_source))\n", - "x = io.imread(os.path.join(Patch_source, random_choice))\n", - "\n", - "#os.chdir(Training_target)\n", - "y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x, interpolation='nearest',cmap='gray')\n", - "plt.title('Training image patch')\n", - "plt.axis('off');\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y, interpolation='nearest',cmap='gray')\n", - "plt.title('Training mask patch')\n", - "plt.axis('off');\n", - "\n", - "plt.savefig('/content/TrainingDataExample_Unet2D.png',bbox_inches='tight',pad_inches=0)\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_jCy7xOx2g3" - }, - "source": [ - "##**3.2. Data augmentation**\n", - "\n", - "---\n", - "\n", - " Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n", - "\n", - " The augmentation options below are to be used as follows:\n", - "\n", - "* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n", - "* **zoom_range**: Increasing or decreasing the field of view. E.g. 10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n", - "* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n", - "* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n", - "* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "DMqWq5-AxnFU", - "cellView": "form" - }, - "source": [ - "#@markdown ##**Augmentation options**\n", - "\n", - "Use_Data_augmentation = True #@param {type:\"boolean\"}\n", - "Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n", - "\n", - "if Use_Data_augmentation:\n", - " if Use_Default_Augmentation_Parameters:\n", - " horizontal_shift = 10 \n", - " vertical_shift = 20 \n", - " zoom_range = 10\n", - " shear_range = 10\n", - " horizontal_flip = True\n", - " vertical_flip = True\n", - " rotation_range = 180\n", - "#@markdown ###If you are not using the default settings, please provide the values below:\n", - "\n", - "#@markdown ###**Image shift, zoom, shear and flip (%)**\n", - " else:\n", - " horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n", - " vertical_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n", - " zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n", - " shear_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n", - " horizontal_flip = True #@param {type:\"boolean\"}\n", - " vertical_flip = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###**Rotate image within angle range (degrees):**\n", - " rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n", - "\n", - "#given behind the # are the default values for each parameter.\n", - "\n", - "else:\n", - " horizontal_shift = 0 \n", - " vertical_shift = 0 \n", - " zoom_range = 0\n", - " shear_range = 0\n", - " horizontal_flip = False\n", - " vertical_flip = False\n", - " rotation_range = 0\n", - "\n", - "\n", - "# Build the dict for the ImageDataGenerator\n", - "data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n", - " height_shift_range = vertical_shift/100.,\n", - " rotation_range = rotation_range, #90\n", - " zoom_range = zoom_range/100.,\n", - " shear_range = shear_range/100.,\n", - " horizontal_flip = horizontal_flip,\n", - " vertical_flip = vertical_flip,\n", - " validation_split = percentage_validation/100,\n", - " fill_mode = 'reflect')\n", - "\n", - "\n", - "\n", - "# ------------- Display ------------\n", - "dir_augmented_data_imgs=\"/content/augment_img\"\n", - "dir_augmented_data_masks=\"/content/augment_mask\"\n", - "random_choice = random.choice(os.listdir(Patch_source))\n", - "orig_img = load_img(os.path.join(Patch_source,random_choice))\n", - "orig_mask = load_img(os.path.join(Patch_target,random_choice))\n", - "\n", - "augment_view = ImageDataGenerator(**data_gen_args)\n", - "\n", - "if Use_Data_augmentation:\n", - " print(\"Parameters enabled\")\n", - " print(\"Here is what a subset of your augmentations looks like:\")\n", - " 
save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n", - " save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n", - "\n", - " fig = plt.figure(figsize=(15, 7))\n", - " fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n", - "\n", - " \n", - " ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[]) \n", - " new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n", - " ax.imshow(new_img)\n", - " ax.set_title('Original Image')\n", - " i = 2\n", - " for imgnm in os.listdir(dir_augmented_data_imgs):\n", - " ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) \n", - " img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n", - " ax.imshow(img)\n", - " i += 1\n", - "\n", - " ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[]) \n", - " new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n", - " ax.imshow(new_mask)\n", - " ax.set_title('Original Mask')\n", - " j=2\n", - " for imgnm in os.listdir(dir_augmented_data_masks):\n", - " ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) \n", - " mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n", - " ax.imshow(mask)\n", - " j += 1\n", - " plt.show()\n", - "\n", - "else:\n", - " print(\"No augmentation will be used\")\n", - "\n", - " " - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3L9zSGtORKYI" - }, - "source": [ - "\n", - "## **3.3. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
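As a minimal, illustrative sketch of the learning-rate lookup described above (separate from the notebook cell that follows, and using a hypothetical model folder path), one could read the Quality Control/training_evaluation.csv of a previous ZeroCostDL4Mic run like this:

import os
import pandas as pd

pretrained_model_path = '/content/gdrive/My Drive/my_unet_model'  # hypothetical example path
csv_path = os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')

if os.path.exists(csv_path):
    history = pd.read_csv(csv_path)
    if 'learning rate' in history.columns:
        # learning rate of the final epoch (Weights_choice = "last")
        lastLearningRate = history['learning rate'].iloc[-1]
        # learning rate at the epoch with the lowest validation loss (Weights_choice = "best")
        bestLearningRate = history.loc[history['val_loss'].idxmin(), 'learning rate']
        print(lastLearningRate, bestLearningRate)
else:
    print('No training_evaluation.csv found: the default initial_learning_rate would be used instead.')

If the CSV is missing or has no 'learning rate' column (models trained outside ZeroCostDL4Mic or with versions below 1.4), the notebook falls back to initial_learning_rate, as the cell below does.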
" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "9vC2n-HeLdiJ", - "cellView": "form" - }, - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "Use_pretrained_model = True #@param {type:\"boolean\"}\n", - "pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n", - "Weights_choice = \"last\" #@param [\"last\", \"best\"]\n", - "\n", - "\n", - "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - "# --------------------- Load the model from the choosen path ------------------------\n", - " if pretrained_model_choice == \"Model_from_file\":\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n", - "\n", - "\n", - "# --------------------- Download the a model provided in the XXX ------------------------\n", - "\n", - " if pretrained_model_choice == \"Model_name\":\n", - " pretrained_model_name = \"Model_name\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " print(\"Downloading the UNET_Model_from_\")\n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path) \n", - " wget.download(\"\", pretrained_model_path)\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n", - "\n", - "# --------------------- Add additional pre-trained models here ------------------------\n", - "\n", - "\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n", - " if not os.path.exists(h5_file_path):\n", - " print(R+'WARNING: pretrained model does not exist')\n", - " Use_pretrained_model = False\n", - " \n", - "\n", - "# If the model path contains a pretrain model, we load the training rate, \n", - " if os.path.exists(h5_file_path):\n", - "#Here we check if the learning rate can be loaded from the quality control folder\n", - " if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - "\n", - " with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " \n", - " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"pretrained network learning rate found\")\n", - " #find the last learning rate\n", - " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n", - " #Find the learning rate corresponding to the lowest validation loss\n", - " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n", - " #print(min_val_loss)\n", - " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n", - "\n", - " if Weights_choice == \"last\":\n", - " print('Last learning rate: '+str(lastLearningRate))\n", - "\n", - " if Weights_choice == \"best\":\n", - " print('Learning rate of best validation loss: 
'+str(bestLearningRate))\n", - "\n", - " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n", - "\n", - "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n", - " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - "\n", - "\n", - "# Display info about the pretrained model to be loaded (or not)\n", - "if Use_pretrained_model:\n", - " print('Weights found in:')\n", - " print(h5_file_path)\n", - " print('will be loaded prior to training.')\n", - "\n", - "else:\n", - " print(R+'No pretrained network will be used.')\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MCGklf1vZf2M" - }, - "source": [ - "# **4. Train the network**\n", - "---\n", - "####**Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time. " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1KYOuygETJkT" - }, - "source": [ - "## **4.1. Prepare the training data and model for training**\n", - "---\n", - "Here, we use the information from 3. to build the model and convert the training data into a suitable format for training." 
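The pairing trick used by buildDoubleGenerator above, giving the image and mask generators the same random seed so that every transformation is applied identically to both, can be sketched with stand-in arrays (all shapes and augmentation values below are illustrative only):

import numpy as np
from keras.preprocessing.image import ImageDataGenerator

datagen_args = dict(rotation_range=180, horizontal_flip=True, fill_mode='reflect')
image_datagen = ImageDataGenerator(**datagen_args)
mask_datagen = ImageDataGenerator(**datagen_args)

x = np.random.rand(8, 256, 256, 1)        # stand-in image patches
y = (x > 0.5).astype(np.float32)          # stand-in binary masks

seed = 1                                   # identical seed -> identical transforms for both streams
image_flow = image_datagen.flow(x, batch_size=4, seed=seed)
mask_flow = mask_datagen.flow(y, batch_size=4, seed=seed)

img_batch, mask_batch = next(zip(image_flow, mask_flow))
print(img_batch.shape, mask_batch.shape)   # (4, 256, 256, 1) for both; pairs stay aligned

The cell below applies the same idea with flow_from_directory on the patch folders created in section 3.1.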
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lIUAOJ_LMv5E", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play this cell to prepare the model for training\n", - "\n", - "\n", - "# ------------------ Set the generators, model and logger ------------------\n", - "# This will take the image size and set that as a patch size (arguable...)\n", - "# Read image size (without actuall reading the data)\n", - "\n", - "(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))\n", - "\n", - "\n", - "# This modelcheckpoint will only save the best model from the validation loss point of view\n", - "model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)\n", - "\n", - "print('Getting class weights...')\n", - "class_weights = getClassWeights(Training_target)\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "#Here we ensure that the learning rate set correctly when using pre-trained models\n", - "if Use_pretrained_model:\n", - " if Weights_choice == \"last\":\n", - " initial_learning_rate = lastLearningRate\n", - "\n", - " if Weights_choice == \"best\": \n", - " initial_learning_rate = bestLearningRate\n", - "else:\n", - " h5_file_path = None\n", - "\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "# --------------------- Reduce learning rate on plateau ------------------------\n", - "\n", - "reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto',\n", - " patience=10, min_lr=0)\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "\n", - "# Define the model\n", - "model = unet(pretrained_weights = h5_file_path, \n", - " input_size = (patch_width,patch_height,1), \n", - " pooling_steps = pooling_steps, \n", - " learning_rate = initial_learning_rate, \n", - " class_weights = class_weights)\n", - "\n", - "config_model= model.optimizer.get_config()\n", - "print(config_model)\n", - "\n", - "\n", - "# ------------------ Failsafes ------------------\n", - "if os.path.exists(full_model_path):\n", - " print(R+'!! WARNING: Model folder already existed and has been removed !!'+W)\n", - " shutil.rmtree(full_model_path)\n", - "\n", - "os.makedirs(full_model_path)\n", - "os.makedirs(os.path.join(full_model_path,'Quality Control'))\n", - "\n", - "\n", - "# ------------------ Display ------------------\n", - "print('---------------------------- Main training parameters ----------------------------')\n", - "print('Number of epochs: '+str(number_of_epochs))\n", - "print('Batch size: '+str(batch_size))\n", - "print('Number of training dataset: '+str(number_of_training_dataset))\n", - "print('Number of training steps: '+str(number_of_steps))\n", - "print('Number of validation steps: '+str(validation_steps))\n", - "print('---------------------------- ------------------------ ----------------------------')\n", - "\n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0Dfn8ZsEMv5d" - }, - "source": [ - "## **4.2. Start Training**\n", - "---\n", - "When playing the cell below you should see updates after each epoch (round). 
Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder." - ] - }, - { - "cell_type": "code", - "metadata": { - "scrolled": true, - "id": "iwNmp1PUzRDQ", - "cellView": "form" - }, - "source": [ - "#@markdown ##Start training\n", - "\n", - "start = time.time()\n", - "# history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs=epochs, callbacks=[model_checkpoint,csv_log], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n", - "history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs = number_of_epochs, callbacks=[model_checkpoint, reduce_lr], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n", - "\n", - "# Save the last model\n", - "model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n", - "\n", - "\n", - "# convert the history.history dict to a pandas DataFrame: \n", - "lossData = pd.DataFrame(history.history) \n", - "\n", - "# The training evaluation.csv is saved (overwrites the Files if needed). \n", - "lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n", - "with open(lossDataCSVpath, 'w') as f:\n", - " writer = csv.writer(f)\n", - " writer.writerow(['loss','val_loss', 'learning rate'])\n", - " for i in range(len(history.history['loss'])):\n", - " writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n", - " \n", - "\n", - "\n", - "# Displaying the time elapsed for training\n", - "print(\"------------------------------------------\")\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "print(\"------------------------------------------\")\n", - "\n", - "#Create a pdf document with training summary\n", - "\n", - "pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_0Hynw3-xHp1" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "eAJzMwPA6tlH", - "cellView": "form" - }, - "source": [ - "#@markdown ###Do you want to assess the model you just trained ?\n", - "\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "QC_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "QC_model_name = os.path.basename(QC_model_folder)\n", - "QC_model_path = os.path.dirname(QC_model_folder)\n", - "\n", - "\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", - "\n", - "\n", - "full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n", - "if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", - "else:\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dhJROwlAMv5o" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vMzSP50kMv5p", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n", - "\n", - "epochNumber = []\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "\n", - "with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[0]))\n", - " vallossDataFromCSV.append(float(row[1]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (log scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'),bbox_inches='tight',pad_inches=0)\n", - "plt.show()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X5_92nL2xdP6" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the image will also be displayed.\n", - "\n", - "The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n", - "\n", - "The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n", - "\n", - " The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n", - "\n", - "### **Thresholds for image masks**\n", - "\n", - " Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. The best threshold for each image and the average of these thresholds will be displayed below. 
**These values can be a guideline when creating masks for unseen data in section 6.**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "w90MdriMxhjD", - "cellView": "form" - }, - "source": [ - "# ------------- User input ------------\n", - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_QC_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "\n", - "# ------------- Initialise folders ------------\n", - "# Create a quality control/Prediction Folder\n", - "prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n", - "if os.path.exists(prediction_QC_folder):\n", - " shutil.rmtree(prediction_QC_folder)\n", - "\n", - "os.makedirs(prediction_QC_folder)\n", - "\n", - "\n", - "# ------------- Prepare the model and run predictions ------------\n", - "\n", - "# Load the model\n", - "unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n", - "Input_size = unet.layers[0].output_shape[1:3]\n", - "print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n", - "\n", - "# Create a list of sources\n", - "source_dir_list = os.listdir(Source_QC_folder)\n", - "number_of_dataset = len(source_dir_list)\n", - "print('Number of dataset found in the folder: '+str(number_of_dataset))\n", - "\n", - "predictions = []\n", - "for i in tqdm(range(number_of_dataset)):\n", - " predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n", - "\n", - "\n", - "# Save the results in the folder along with the masks according to the set threshold\n", - "saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n", - "\n", - "#-----------------------------Calculate Metrics----------------------------------------#\n", - "\n", - "f = plt.figure(figsize=((5,5)))\n", - "\n", - "with open(os.path.join(full_QC_model_path,'Quality Control', 'QC_metrics_'+QC_model_name+'.csv'), \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - " writer.writerow([\"File name\",\"IoU\", \"IoU-optimised threshold\"]) \n", - "\n", - " # Initialise the lists \n", - " filename_list = []\n", - " best_threshold_list = []\n", - " best_IoU_score_list = []\n", - "\n", - " for filename in os.listdir(Source_QC_folder):\n", - "\n", - " if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n", - " print('Running QC on: '+filename)\n", - " test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n", - " test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n", - "\n", - " (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(os.path.join(prediction_QC_folder, prediction_prefix+filename), os.path.join(Target_QC_folder, filename))\n", - " plt.plot(threshold_list,iou_scores_per_threshold, label=filename)\n", - "\n", - " # Here we find which threshold yielded the highest IoU score for image n.\n", - " best_IoU_score = max(iou_scores_per_threshold)\n", - " best_threshold = iou_scores_per_threshold.index(best_IoU_score)\n", - "\n", - " # Write the results in the CSV file\n", - " writer.writerow([filename, str(best_IoU_score), str(best_threshold)])\n", - "\n", - " # Here we append the best threshold and score to the lists\n", - " filename_list.append(filename)\n", - " best_IoU_score_list.append(best_IoU_score)\n", - " 
best_threshold_list.append(best_threshold)\n", - "\n", - "# Display the IoV vs Threshold plot\n", - "plt.title('IoU vs. Threshold')\n", - "plt.ylabel('IoU')\n", - "plt.xlabel('Threshold value')\n", - "plt.legend()\n", - "plt.savefig(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png',bbox_inches='tight',pad_inches=0)\n", - "plt.show()\n", - "\n", - "\n", - "# Table with metrics as dataframe output\n", - "pdResults = pd.DataFrame(index = filename_list)\n", - "pdResults[\"IoU\"] = best_IoU_score_list\n", - "pdResults[\"IoU-optimised threshold\"] = best_threshold_list\n", - "\n", - "average_best_threshold = sum(best_threshold_list)/len(best_threshold_list)\n", - "\n", - "\n", - "# ------------- For display ------------\n", - "print('--------------------------------------------------------------')\n", - "@interact\n", - "def show_QC_results(file=os.listdir(Source_QC_folder)):\n", - " \n", - " plt.figure(figsize=(25,5))\n", - " #Input\n", - " plt.subplot(1,4,1)\n", - " plt.axis('off')\n", - " plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n", - " plt.title('Input')\n", - "\n", - " #Ground-truth\n", - " plt.subplot(1,4,2)\n", - " plt.axis('off')\n", - " test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n", - " plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n", - " plt.title('Ground Truth')\n", - "\n", - " #Prediction\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n", - " test_prediction_mask = np.empty_like(test_prediction)\n", - " test_prediction_mask[test_prediction > average_best_threshold] = 255\n", - " test_prediction_mask[test_prediction <= average_best_threshold] = 0\n", - " plt.imshow(test_prediction_mask, aspect='equal', cmap='Purples')\n", - " plt.title('Prediction')\n", - "\n", - " #Overlay\n", - " plt.subplot(1,4,4)\n", - " plt.axis('off')\n", - " plt.imshow(test_ground_truth_image, cmap='Greens')\n", - " plt.imshow(test_prediction_mask, alpha=0.5, cmap='Purples')\n", - " metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ' T: ' + str(round(pdResults.loc[file][\"IoU-optimised threshold\"])) + ')'\n", - " plt.title(metrics_title)\n", - " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "\n", - "print('--------------------------------------------------------------')\n", - "print('Best average threshold is: '+str(round(average_best_threshold)))\n", - "print('--------------------------------------------------------------')\n", - "\n", - "pdResults.head()\n", - "\n", - "qc_pdf_export()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8MIEOw6o11q8" - }, - "source": [ - "## **5.3. Export your model into the BioImage Model Zoo format**\n", - "---\n", - "This section exports the model into the BioImage Model Zoo format so it can be used directly with DeepImageJ. The new files will be stored in the model folder specified at the beginning of Section 5. \n", - "\n", - "Once the cell is executed, you will find a new zip file with the name specified in `Trained_model_name.bioimage.io.model`.\n", - "\n", - "To use it with deepImageJ, download it and unzip it in the ImageJ/models/ or Fiji/models/ folder of your local machine. 
\n", - "\n", - "In ImageJ, open the example image given within the downloaded zip file. Go to Plugins > DeepImageJ > DeepImageJ Run. Choose this model from the list and click OK.\n", - "\n", - " More information at https://deepimagej.github.io/deepimagej/" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "y7a_acB114o7", - "cellView": "form" - }, - "source": [ - "# ------------- User input ------------\n", - "# information about the model\n", - "#@markdown ##Introduce the metadata of the model architecture:\n", - "Trained_model_name = \"\" #@param {type:\"string\"}\n", - "Trained_model_authors = \"[Author 1, Author 2, Author 3]\" #@param {type:\"string\"}\n", - "Trained_model_description = \"\" #@param {type:\"string\"}\n", - "Trained_model_license = 'MIT'#@param {type:\"string\"}\n", - "Trained_model_references = [\"Falk et al. Nature Methods 2019\", \"Ronneberger et al. arXiv in 2015\", \"Lucas von Chamier et al. biorXiv 2020\"] \n", - "Trained_model_DOI = [\"https://doi.org/10.1038/s41592-018-0261-2\",\"https://doi.org/10.1007/978-3-319-24574-4_28\", \"https://doi.org/10.1101/2020.03.20.000133\"] \n", - "\n", - "#@markdown ##Choose a threshold for DeepImageJ's postprocessing macro:\n", - "Use_The_Best_Average_Threshold = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "threshold = 85 #@param {type:\"number\"}\n", - "\n", - "#@markdown ##Introduce the pixel size (in microns) of the image provided as an example of the model processing:\n", - "# information about the example image\n", - "PixelSize = 0.0004 #@param {type:\"number\"}\n", - "#@markdown ##Do you want to choose the exampleimage?\n", - "default_example_image = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "fileID = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "if Use_The_Best_Average_Threshold:\n", - " threshold = average_best_threshold\n", - "\n", - "from skimage import io\n", - "\n", - "if default_example_image:\n", - " source_dir_list = os.listdir(Source_QC_folder)\n", - " fileID = os.path.join(Source_QC_folder, source_dir_list[0])\n", - " \n", - "\n", - "# Read the input image\n", - "test_img = io.imread(fileID)\n", - "\n", - "# Load the model\n", - "unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n", - "test_prediction = predict_as_tiles(fileID, unet)\n", - "# test_prediction = io.imread(os.path.join(prediction_QC_folder, prediction_prefix+fileID))\n", - "# # Binarize it with the threshold chosen\n", - "test_prediction_mask = convert2Mask(test_prediction, threshold)\n", - "\n", - "## Run this cell to export the model to the BioImage Model Zoo format.\n", - "####\n", - "\n", - "from pydeepimagej.yaml import BioImageModelZooConfig\n", - "import urllib\n", - "\n", - "# Check minimum size: it is [8,8] for the 2D XY plane\n", - "pooling_steps = 0\n", - "for keras_layer in unet.layers:\n", - " if keras_layer.name.startswith('max') or \"pool\" in keras_layer.name:\n", - " pooling_steps += 1\n", - "MinimumSize = [2**(pooling_steps), 2**(pooling_steps)]\n", - "\n", - "dij_config = BioImageModelZooConfig(unet, MinimumSize)\n", - "\n", - "\n", - "# Model developer details\n", - "dij_config.Authors = Trained_model_authors[1:-1].split(',')\n", - "dij_config.Description = Trained_model_description\n", - "dij_config.Name = Trained_model_name\n", - "dij_config.References = Trained_model_references\n", - "dij_config.DOI = 
Trained_model_DOI\n", - "dij_config.License = Trained_model_license\n", - "\n", - "# Additional information about the model\n", - "dij_config.GitHub = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic'\n", - "dij_config.Date = datetime.now()\n", - "dij_config.Documentation = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki'\n", - "dij_config.Tags = ['ZeroCostDL4Mic', 'deepimagej', 'segmentation', 'TEM', 'unet']\n", - "dij_config.Framework = 'tensorflow'\n", - "\n", - "\n", - "# Add the information about the test image. Note here PixelSize is given in nm\n", - "dij_config.add_test_info(test_img, test_prediction_mask, [PixelSize, PixelSize])\n", - "dij_config.create_covers([test_img, test_prediction_mask])\n", - "dij_config.Covers = ['./input.png', './output.png']\n", - "\n", - "\n", - "## Prepare preprocessing file\n", - "min_percentile = 0\n", - "max_percentile = 99.85\n", - "\n", - "path_preprocessing = \"per_sample_scale_range.ijm\"\n", - "urllib.request.urlretrieve('https://raw.githubusercontent.com/deepimagej/imagej-macros/master/bioimage.io/per_sample_scale_range.ijm', path_preprocessing )\n", - "\n", - "# Modify the threshold in the macro to the chosen threshold\n", - "ijmacro = open(path_preprocessing,\"r\") \n", - "list_of_lines = ijmacro. readlines()\n", - "# Line 21 is the one corresponding to the optimal threshold\n", - "list_of_lines[24] = \"min_percentile = {};\\n\".format(min_percentile)\n", - "list_of_lines[25] = \"max_percentile = {};\\n\".format(max_percentile)\n", - "ijmacro = open(path_preprocessing,\"w\") \n", - "ijmacro. writelines(list_of_lines)\n", - "ijmacro. close()\n", - "# Include info about the macros \n", - "dij_config.Preprocessing = [path_preprocessing]\n", - "dij_config.Preprocessing_files = [path_preprocessing]\n", - "# Preprocessing following BioImage Model Zoo specifications\n", - "dij_config.add_bioimageio_spec('pre-processing', 'scale_range',\n", - " mode='per_sample', axes='xyzc',\n", - " min_percentile = min_percentile,\n", - " max_percentile = max_percentile)\n", - "\n", - "## Prepare postprocessing file\n", - "path_postprocessing = \"8bitBinarize.ijm\"\n", - "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/8bitBinarize.ijm\", path_postprocessing )\n", - "\n", - "# Modify the threshold in the macro to the chosen threshold\n", - "ijmacro = open(path_postprocessing,\"r\") \n", - "list_of_lines = ijmacro. readlines()\n", - "# Line 21 is the one corresponding to the optimal threshold\n", - "list_of_lines[21] = \"optimalThreshold = {};\\n\".format(threshold)\n", - "ijmacro.close()\n", - "ijmacro = open(path_postprocessing,\"w\") \n", - "ijmacro. writelines(list_of_lines)\n", - "ijmacro. 
close()\n", - "\n", - "# Include info about the macros \n", - "dij_config.Postprocessing = [path_postprocessing]\n", - "dij_config.Postprocessing_files = [path_postprocessing]\n", - "# Preprocessing following BioImage Model Zoo specifications\n", - "dij_config.add_bioimageio_spec('post-processing', 'scale_range',\n", - " mode='per_sample', axes='xyzc', \n", - " min_percentile=0, max_percentile=100)\n", - "\n", - "dij_config.add_bioimageio_spec('post-processing', 'scale_linear',\n", - " gain=255, offset=0, axes='xy')\n", - "\n", - "dij_config.add_bioimageio_spec('post-processing', 'binarize',\n", - " threshold=threshold)\n", - "\n", - "\n", - "# Store the model weights\n", - "# ---------------------------------------\n", - "# used_bioimageio_model_for_training_URL = \"/Some/URL/bioimage.io/\"\n", - "# dij_config.Parent = used_bioimageio_model_for_training_URL\n", - "\n", - "# Add weights information\n", - "format_authors = [\"pydeepimagej\"]\n", - "dij_config.add_weights_formats(unet, 'TensorFlow', \n", - " parent=\"keras_hdf5\",\n", - " authors=[a for a in format_authors])\n", - "dij_config.add_weights_formats(unet, 'KerasHDF5', \n", - " authors=[a for a in format_authors])\n", - "\n", - "## EXPORT THE MODEL\n", - "deepimagej_model_path = os.path.join(full_QC_model_path, Trained_model_name+'.bioimage.io.model')\n", - "dij_config.export_model(deepimagej_model_path)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-tJeeJjLnRkP" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d8wuQGjoq6eN" - }, - "source": [ - "## **6.1 Generate prediction(s) from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder.\n", - "\n", - "**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output images.\n", - "\n", - " Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n", - "\n", - " **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. Train the network again with different network hyperparameters." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "y2TD5p7MZrEb", - "cellView": "form" - }, - "source": [ - "\n", - "\n", - "# ------------- Initial user input ------------\n", - "#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n", - "Data_folder = '' #@param {type:\"string\"}\n", - "Results_folder = '' #@param {type:\"string\"}\n", - "\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "\n", - "# ------------- Failsafes ------------\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "\n", - "# ------------- Prepare the model and run predictions ------------\n", - "\n", - "# Load the model and prepare generator\n", - "\n", - "\n", - "\n", - "unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n", - "Input_size = unet.layers[0].output_shape[1:3]\n", - "print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n", - "\n", - "# Create a list of sources\n", - "source_dir_list = os.listdir(Data_folder)\n", - "number_of_dataset = len(source_dir_list)\n", - "print('Number of dataset found in the folder: '+str(number_of_dataset))\n", - "\n", - "predictions = []\n", - "for i in tqdm(range(number_of_dataset)):\n", - " predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n", - " # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n", - "\n", - "\n", - "# Save the results in the folder along with the masks according to the set threshold\n", - "saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n", - "\n", - "\n", - "# ------------- For display ------------\n", - "print('--------------------------------------------------------------')\n", - "\n", - "\n", - "def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n", - "\n", - " plt.figure(figsize=(18,6))\n", - " # Wide-field\n", - " plt.subplot(1,3,1)\n", - " plt.axis('off')\n", - " img_Source = plt.imread(os.path.join(Data_folder, file))\n", - " plt.imshow(img_Source, cmap='gray')\n", - " plt.title('Source image',fontsize=15)\n", - " # Prediction\n", - " plt.subplot(1,3,2)\n", - " plt.axis('off')\n", - " img_Prediction = plt.imread(os.path.join(Results_folder, 
prediction_prefix+file))\n", - " plt.imshow(img_Prediction, cmap='gray')\n", - " plt.title('Prediction',fontsize=15)\n", - "\n", - " # Thresholded mask\n", - " plt.subplot(1,3,3)\n", - " plt.axis('off')\n", - " img_Mask = convert2Mask(img_Prediction, threshold)\n", - " plt.imshow(img_Mask, cmap='gray')\n", - " plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n", - "\n", - "\n", - "interact(show_prediction_mask, continuous_update=False);\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "stS96mFZLMOU" - }, - "source": [ - "## **6.2. Export results as masks**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "qb5ZmFstLNbR" - }, - "source": [ - "\n", - "# @markdown #Play this cell to save results as masks with the chosen threshold\n", - "threshold = 120#@param {type:\"number\"}\n", - "\n", - "saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n", - "print('-------------------')\n", - "print('Masks were saved in: '+Results_folder)\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.3. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UvSlTaH14s3t" - }, - "source": [ - "#**Thank you for using 2D U-Net!**\n" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_2D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb","provenance":[{"file_id":"1USP7Bmd4UEdhp9cOBc_wlqDXnZjMHk_f","timestamp":1622215174280},{"file_id":"1EZG34jBKULVmO__Fmv7Lr76sVHIMxwJx","timestamp":1622041273450},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **U-Net (2D)**\n","---\n","\n","U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n","\n"," **This particular notebook enables image segmentation of 2D dataset. 
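For readers new to the architecture, the toy sketch below illustrates the encoder-decoder idea with a single downsampling/upsampling step and one skip connection. It is deliberately simplified and is not the model built by this notebook (the full builder appears in section 1.3).

```python
# Toy illustration (not the notebook's model): one U-Net down/up step with a
# skip connection, using the same Keras layers the notebook imports later.
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
from keras.models import Model

inputs = Input((256, 256, 1))
c1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)   # encoder features
p1 = MaxPooling2D(2)(c1)                                        # downsample
c2 = Conv2D(128, 3, activation='relu', padding='same')(p1)      # bottleneck
u1 = UpSampling2D(2)(c2)                                        # upsample
m1 = concatenate([c1, u1])                                      # skip connection
out = Conv2D(1, 1, activation='sigmoid')(m1)                    # per-pixel mask
model = Model(inputs, out)
model.summary()
```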
If you are interested in 3D dataset, you should use the 3D U-Net notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the papers: \n","\n","**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n","\n","and \n","\n","**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n","(https://www.nature.com/articles/s41592-018-0261-2)\n","And source code found in: /~https://github.com/zhixuhao/unet by *Zhixuhao*\n","\n","**Please also cite this original paper when using or developing this notebook.** "]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. 
This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For U-Net to train, **it needs to have access to a paired training dataset corresponding to images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif, ...\n"," - Training_target\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif\n"," - Training_target \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. Install U-Net dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"UGWnGOFsf07b"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"uc0haIa-fZiG","cellView":"form"},"source":["#@markdown ##Play to install U-Net dependencies\n","\n","!pip install pydeepimagej==2.1.2\n","!pip install data\n","!pip install fpdf\n","!pip install h5py==2.10\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"I4O5zctbf4Gb"},"source":["\n","## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
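Since Training_source and Training_target must contain identically named files (see the data organisation described in section 0 above), a quick check such as the hedged sketch below can catch mismatches before a training run fails. The folder paths are placeholders to adapt.

```python
# Hedged sketch: verify that every image in Training_source has a mask with the
# same file name in Training_target, as required above. Paths are placeholders.
import os

source = "/content/gdrive/MyDrive/Training_source"   # placeholder path
target = "/content/gdrive/MyDrive/Training_target"   # placeholder path

missing = sorted(set(os.listdir(source)) - set(os.listdir(target)))
if missing:
    print("No matching mask found for:", missing)
else:
    print("All", len(os.listdir(source)), "images have a matching mask.")
```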
"]},{"cell_type":"markdown","metadata":{"id":"iiX3Ly-7gA5h"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'U-Net (2D)'\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Load key U-Net dependencies\n","\n","#As this notebokk depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)\n","#only the data library needs to be additionally installed.\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","# print(tensorflow.__version__)\n","# print(\"Tensorflow enabled.\")\n","\n","\n","# Keras imports\n","from keras import models\n","from keras.models import Model, load_model\n","from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n","from keras.optimizers import Adam\n","# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints\n","from keras.callbacks import ModelCheckpoint\n","from keras.callbacks import ReduceLROnPlateau\n","from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n","from keras import backend as keras\n","\n","# General import\n","from __future__ import print_function\n","import numpy as np\n","import pandas as pd\n","import os\n","import glob\n","from skimage import img_as_ubyte, io, transform\n","import matplotlib as mpl\n","from matplotlib import pyplot as plt\n","from matplotlib.pyplot import imread\n","from pathlib import Path\n","import shutil\n","import random\n","import time\n","import csv\n","import sys\n","from math import ceil\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","# Imports for QC\n","from PIL import Image\n","from scipy import signal\n","from scipy import ndimage\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import 
peak_signal_noise_ratio as psnr\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","# from tqdm import tqdm\n","from tqdm.notebook import tqdm\n","\n","from sklearn.feature_extraction import image\n","from skimage import img_as_ubyte, io, transform\n","from skimage.util.shape import view_as_windows\n","\n","from datetime import datetime\n","\n","\n","# Suppressing some warnings\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","\n","\n","\n","def create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction):\n"," \"\"\"\n"," Function creates patches from the Training_source and Training_target images. \n"," The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n"," Saves all created patches in two new directories in the /content folder.\n","\n"," min_fraction is the minimum fraction of pixels that need to be foreground to be considered as a valid patch\n","\n"," Returns: - Two paths to where the patches are now saved\n"," \"\"\"\n"," DEBUG = False\n","\n"," Patch_source = os.path.join('/content','img_patches')\n"," Patch_target = os.path.join('/content','mask_patches')\n"," Patch_rejected = os.path.join('/content','rejected')\n"," \n","\n"," #Here we save the patches, in the /content directory as they will not usually be needed after training\n"," if os.path.exists(Patch_source):\n"," shutil.rmtree(Patch_source)\n"," if os.path.exists(Patch_target):\n"," shutil.rmtree(Patch_target)\n"," if os.path.exists(Patch_rejected):\n"," shutil.rmtree(Patch_rejected)\n","\n"," os.mkdir(Patch_source)\n"," os.mkdir(Patch_target)\n"," os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n"," \n"," patch_num = 0\n","\n"," for file in tqdm(os.listdir(Training_source)):\n","\n"," img = io.imread(os.path.join(Training_source, file))\n"," mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n","\n"," if DEBUG:\n"," print(file)\n"," print(img.dtype)\n","\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n"," patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n","\n"," patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)\n"," patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)\n","\n"," if DEBUG:\n"," print(all_patches_img.shape)\n"," print(all_patches_img.dtype)\n","\n"," for i in range(patches_img.shape[0]):\n"," img_save_path = os.path.join(Patch_source,'patch_'+str(patch_num)+'.tif')\n"," mask_save_path = os.path.join(Patch_target,'patch_'+str(patch_num)+'.tif')\n"," patch_num += 1\n","\n"," # if the mask conatins at least 2% of its total number pixels as mask, then go ahead and save the images\n"," pixel_threshold_array = sorted(patches_mask[i].flatten())\n"," if pixel_threshold_array[int(round(len(pixel_threshold_array)*(1-min_fraction)))]>0:\n"," io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(patches_img[i])))\n"," io.imsave(mask_save_path, convert2Mask(normalizeMinMax(patches_mask[i]),0))\n"," else:\n"," io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_image.tif', img_as_ubyte(normalizeMinMax(patches_img[i])))\n"," io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_mask.tif', 
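The create_patches function above uses skimage's view_as_windows with the step equal to the patch size, so the extracted tiles do not overlap. A minimal, self-contained illustration with made-up dimensions:

```python
# Hedged sketch of the non-overlapping patch extraction used in create_patches:
# with step == window size, view_as_windows returns a grid of tiles that is then
# flattened into a stack of (patch_width, patch_height) images.
import numpy as np
from skimage.util.shape import view_as_windows

img = np.zeros((1024, 1024), dtype=np.uint8)        # dummy image
patches = view_as_windows(img, (512, 512), (512, 512))  # shape (2, 2, 512, 512)
patches = patches.reshape(-1, 512, 512)                  # shape (4, 512, 512)
print(patches.shape)
```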
convert2Mask(normalizeMinMax(patches_mask[i]),0))\n","\n"," return Patch_source, Patch_target\n","\n","\n","def estimatePatchSize(data_path, max_width = 512, max_height = 512):\n","\n"," files = os.listdir(data_path)\n"," \n"," # Get the size of the first image found in the folder and initialise the variables to that\n"," n = 0 \n"," while os.path.isdir(os.path.join(data_path, files[n])):\n"," n += 1\n"," (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n","\n"," # Screen the size of all dataset to find the minimum image size\n"," for file in files:\n"," if not os.path.isdir(os.path.join(data_path, file)):\n"," (height, width) = Image.open(os.path.join(data_path, file)).size\n"," if width < width_min:\n"," width_min = width\n"," if height < height_min:\n"," height_min = height\n"," \n"," # Find the power of patches that will fit within the smallest dataset\n"," width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n","\n"," # Clip values at maximum permissible values\n"," if width_min > max_width:\n"," width_min = max_width\n","\n"," if height_min > max_height:\n"," height_min = max_height\n"," \n"," return (width_min, height_min)\n","\n","def fittingPowerOfTwo(number):\n"," n = 0\n"," while 2**n <= number:\n"," n += 1 \n"," return 2**(n-1)\n","\n","\n","def getClassWeights(Training_target_path):\n","\n"," Mask_dir_list = os.listdir(Training_target_path)\n"," number_of_dataset = len(Mask_dir_list)\n","\n"," class_count = np.zeros(2, dtype=int)\n"," for i in tqdm(range(number_of_dataset)):\n"," mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n"," mask = normalizeMinMax(mask)\n"," class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n"," class_count[1] += mask.sum()\n","\n"," n_samples = class_count.sum()\n"," n_classes = 2\n","\n"," class_weights = n_samples / (n_classes * class_count)\n"," return class_weights\n","\n","def weighted_binary_crossentropy(class_weights):\n","\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n"," weight_vector = y_true * class_weights[1] + (1. - y_true) * class_weights[0]\n"," weighted_binary_crossentropy = weight_vector * binary_crossentropy\n","\n"," return keras.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","\n","def save_augment(datagen,orig_img,dir_augmented_data=\"/content/augment\"):\n"," \"\"\"\n"," Saves a subset of the augmented data for visualisation, by default in /content.\n","\n"," This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n"," \n"," \"\"\"\n"," try:\n"," os.mkdir(dir_augmented_data)\n"," except:\n"," ## if the preview folder exists, then remove\n"," ## the contents (pictures) in the folder\n"," for item in os.listdir(dir_augmented_data):\n"," os.remove(dir_augmented_data + \"/\" + item)\n","\n"," ## convert the original image to array\n"," x = img_to_array(orig_img)\n"," ## reshape (Sampke, Nrow, Ncol, 3) 3 = R, G or B\n"," #print(x.shape)\n"," x = x.reshape((1,) + x.shape)\n"," #print(x.shape)\n"," ## -------------------------- ##\n"," ## randomly generate pictures\n"," ## -------------------------- ##\n"," i = 0\n"," #We will just save 5 images,\n"," #but this can be changed, but note the visualisation in 3. 
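getClassWeights and weighted_binary_crossentropy above compensate for the strong class imbalance typical of segmentation masks, where most pixels are background. A small numeric illustration, assuming an arbitrary example in which 5% of pixels are foreground:

```python
# Hedged numeric illustration of getClassWeights: for masks where 5% of pixels
# are foreground, the weights penalise errors on the rare class more strongly.
import numpy as np

class_count = np.array([0.95 * 512 * 512, 0.05 * 512 * 512])  # [background, foreground]
n_samples, n_classes = class_count.sum(), 2
class_weights = n_samples / (n_classes * class_count)
print(class_weights)  # approx. [0.53, 10.0]: foreground errors weighted ~19x more
```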
currently uses 5.\n"," Nplot = 5\n"," for batch in datagen.flow(x,batch_size=1,\n"," save_to_dir=dir_augmented_data,\n"," save_format='tif',\n"," seed=42):\n"," i += 1\n"," if i > Nplot - 1:\n"," break\n","\n","# Generators\n","def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):\n"," '''\n"," Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n"," \n"," datagen: ImageDataGenerator \n"," subset: can take either 'training' or 'validation'\n"," '''\n"," seed = 1\n"," image_generator = image_datagen.flow_from_directory(\n"," os.path.dirname(image_folder_path),\n"," classes = [os.path.basename(image_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"bicubic\",\n"," seed = seed)\n"," \n"," mask_generator = mask_datagen.flow_from_directory(\n"," os.path.dirname(mask_folder_path),\n"," classes = [os.path.basename(mask_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"nearest\",\n"," seed = seed)\n"," \n"," this_generator = zip(image_generator, mask_generator)\n"," for (img,mask) in this_generator:\n"," # img,mask = adjustData(img,mask)\n"," yield (img,mask)\n","\n","\n","def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):\n"," image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n"," mask_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizeMinMax)\n","\n"," train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, target_size)\n"," validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size)\n","\n"," return (train_datagen, validation_datagen)\n","\n","\n","# Normalization functions from Martin Weigert\n","def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","\n","\n","# Simple normalization to min/max fir the Mask\n","def normalizeMinMax(x, dtype=np.float32):\n"," x = x.astype(dtype,copy=False)\n"," x = (x - np.amin(x)) / (np.amax(x) - np.amin(x))\n"," return x\n","\n","\n","# This is code outlines the 
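normalizePercentile above rescales intensities so that a low and a high percentile map approximately to 0 and 1, which makes training robust to hot pixels and other outliers. A short illustration on random data:

```python
# Hedged illustration of the percentile normalisation defined above: a synthetic
# 16-bit-range image is rescaled so the 1st percentile maps to ~0 and the 99.8th
# percentile to ~1.
import numpy as np

img = np.random.randint(100, 4000, size=(256, 256)).astype(np.float32)
mi, ma = np.percentile(img, 1), np.percentile(img, 99.8)
img_norm = (img - mi) / (ma - mi + 1e-20)
print(img_norm.min(), img_norm.max())  # close to 0 and 1; a few pixels fall
                                       # slightly outside because they lie beyond
                                       # the chosen percentiles
```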
architecture of U-net. The choice of pooling steps decides the depth of the network. \n","def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):\n"," inputs = Input(input_size)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n"," # Downsampling steps\n"," pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n"," \n"," if pooling_steps > 1:\n"," pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n","\n"," if pooling_steps > 2:\n"," pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)\n"," drop4 = Dropout(0.5)(conv4)\n"," \n"," if pooling_steps > 3:\n"," pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n"," drop5 = Dropout(0.5)(conv5)\n","\n"," #Upsampling steps\n"," up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))\n"," merge6 = concatenate([drop4,up6], axis = 3)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n"," \n"," if pooling_steps > 2:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n"," if pooling_steps > 3:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n"," merge7 = concatenate([conv3,up7], axis = 3)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)\n","\n"," if pooling_steps > 1:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n"," if pooling_steps > 2:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n"," merge8 = concatenate([conv2,up8], axis = 3)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)\n"," \n"," if pooling_steps == 1:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n"," else:\n"," up9 = Conv2D(64, 2, padding = 'same', 
kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n"," \n"," merge9 = concatenate([conv1,up9], axis = 3)\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)\n","\n"," model = Model(inputs = inputs, outputs = conv10)\n","\n"," # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])\n"," model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))\n","\n","\n"," if verbose:\n"," model.summary()\n","\n"," if(pretrained_weights):\n"," \tmodel.load_weights(pretrained_weights);\n","\n"," return model\n","\n","\n","\n","def predict_as_tiles(Image_path, model):\n","\n"," # Read the data in and normalize\n"," Image_raw = io.imread(Image_path, as_gray = True)\n"," Image_raw = normalizePercentile(Image_raw)\n","\n"," # Get the patch size from the input layer of the model\n"," patch_size = model.layers[0].output_shape[1:3]\n","\n"," # Pad the image with zeros if any of its dimensions is smaller than the patch size\n"," if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n"," Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n"," Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n"," else:\n"," Image = Image_raw\n","\n"," # Calculate the number of patches in each dimension\n"," n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n"," n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n","\n"," prediction = np.zeros(Image.shape)\n","\n"," for x in range(n_patch_in_width):\n"," for y in range(n_patch_in_height):\n"," xi = patch_size[0]*x\n"," yi = patch_size[1]*y\n","\n"," # If the patch exceeds the edge of the image shift it back \n"," if xi+patch_size[0] >= Image.shape[0]:\n"," xi = Image.shape[0]-patch_size[0]\n","\n"," if yi+patch_size[1] >= Image.shape[1]:\n"," yi = Image.shape[1]-patch_size[1]\n"," \n"," # Extract and reshape the patch\n"," patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n"," patch = np.reshape(patch,patch.shape+(1,))\n"," patch = np.reshape(patch,(1,)+patch.shape)\n","\n"," # Get the prediction from the patch and paste it in the prediction in the right place\n"," predicted_patch = model.predict(patch, batch_size = 1)\n"," prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = np.squeeze(predicted_patch)\n","\n","\n"," return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n"," \n","\n","\n","\n","def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):\n"," for (filename, image) in zip(source_dir_list, nparray):\n"," io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), img_as_ubyte(image)) # saving as unsigned 8-bit image\n"," \n"," # For masks, threshold the images and return 8 bit image\n"," if threshold is not None:\n"," mask = convert2Mask(image, threshold)\n"," io.imsave(os.path.join(save_path, prefix+'mask_'+os.path.splitext(filename)[0]+'.tif'), mask)\n","\n","\n","def convert2Mask(image, threshold):\n"," mask = img_as_ubyte(image, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n"," return mask\n","\n","\n","def 
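predict_as_tiles above steps through the image one patch at a time and shifts the final tile back so it never extends past the border, so edge tiles overlap slightly instead of requiring extra padding. The tiling arithmetic can be sketched in isolation; the dimensions below are arbitrary.

```python
# Hedged sketch of the tiling logic in predict_as_tiles: tile origins are spaced
# by the patch size, and the last origin is clamped so the tile stays inside the
# image. Assumes the image is at least one patch wide/tall, as guaranteed by the
# padding step in predict_as_tiles.
from math import ceil

def tile_origins(length, patch):
    return [min(i * patch, length - patch) for i in range(ceil(length / patch))]

print(tile_origins(600, 256))  # [0, 256, 344] -> the last tile covers 344..600
```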
getIoUvsThreshold(prediction_filepath, groud_truth_filepath):\n"," prediction = io.imread(prediction_filepath)\n"," ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)\n","\n"," threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," # Convert to 8-bit for calculating the IoU\n"," mask = img_as_ubyte(prediction, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n","\n"," # Intersection over Union metric\n"," intersection = np.logical_and(ground_truth_image, np.squeeze(mask))\n"," union = np.logical_or(ground_truth_image, np.squeeze(mask))\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return (threshold_list, IoU_scores_list)\n","\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net and dependencies installed.')\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","# Check if this is the latest version of the notebook\n","\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n"," loss = str(model.loss)[str(model.loss).find('function')+len('function'):str(model.loss).find('.<')]\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. 
Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(180, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=1)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if rotation_range != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if horizontal_flip == True or vertical_flip == True:\n"," aug_text = aug_text+'\\n- flipping'\n"," if zoom_range != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if horizontal_shift != 0 or vertical_shift != 0:\n"," aug_text = aug_text+'\\n- shifting'\n"," if shear_range != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40% style=\"margin-left:0px;\">\n","  <tr>\n","    <th width = 50% align=\"left\">Parameter</th>\n","    <th width = 50% align=\"left\">Value</th>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>number_of_epochs</td>\n","    <td width = 50%>{0}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>patch_size</td>\n","    <td width = 50%>{1}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>batch_size</td>\n","    <td width = 50%>{2}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>number_of_steps</td>\n","    <td width = 50%>{3}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>percentage_validation</td>\n","    <td width = 50%>{4}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>initial_learning_rate</td>\n","    <td width = 50%>{5}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>pooling_steps</td>\n","    <td width = 50%>{6}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>min_fraction</td>\n","    <td width = 50%>{7}</td>\n","  </tr>\n","
  </table>
\n"," \"\"\".format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size, number_of_steps, percentage_validation, initial_learning_rate, pooling_steps, min_fraction)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Unet2D.png').shape\n"," pdf.image('/content/TrainingDataExample_Unet2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Unet 2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Threshold Optimisation', ln=1, align='L')\n"," #pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png', x = 11, y = None, w = round(exp_size[1]/6), h = round(exp_size[0]/7))\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," IoU = header[1]\n"," IoU_OptThresh = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,IoU,IoU_OptThresh)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," image = 
row[0]\n"," IoU = row[1]\n"," IoU_OptThresh = row[2]\n"," cells = \"\"\"\n"," <tr>\n"," <td width = 33% align=\"left\">{0}</td>\n"," <td width = 33% align=\"left\">{1}</td>\n"," <td width = 33% align=\"left\">{2}</td>\n"," </tr>\"\"\".format(image,str(round(float(IoU),3)),str(round(float(IoU_OptThresh),3)))\n"," html = html+cells\n"," html = html+\"\"\"
</body></table>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n"," print('------------------------------')\n"," print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Complete the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dm3eCMYB5d-H"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
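If the Files tab still looks empty after mounting, a quick way to confirm that the mount itself succeeded is to list the mount point from a code cell. The snippet below is a minimal check that is not part of the original notebook; it assumes the default `/content/gdrive` mount point used by the cell above and the standard `MyDrive` top-level folder.

```python
# Minimal sanity check (not part of the original notebook) that Google Drive is mounted.
import os

mount_point = '/content/gdrive'                    # path passed to drive.mount() above
my_drive = os.path.join(mount_point, 'MyDrive')    # standard top-level folder of a personal Drive

if os.path.isdir(my_drive):
    print('Drive is mounted; top-level items found:', len(os.listdir(my_drive)))
else:
    print('Drive does not appear to be mounted - re-run the mounting cell above.')
```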
"]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training data and models**\n","\n","**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (segmentation masks). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n","\n","**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n","\n","**`model_path`**: Enter the path of the folder where you want to save your model.\n","\n","**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n","\n"," **Select training parameters**\n","\n","**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. Smaller batchsizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset this can be due to a too large batch size. Decrease the number in this case. **Default: 4**\n","\n","**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. The default value is calculated to ensure this. This behaviour can also be obtained by setting it to 0. Other values can be used for testing.\n","\n"," **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**`patch_width` and `patch_height`:** The notebook crops the data in patches of fixed size prior to training. The dimensions of the patches can be defined here. When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n","\n","**`min_fraction`:** Minimum fraction of pixels being foreground for a slected patch to be considered valid. 
It should be between 0 and 1.**Default value: 0.02** (2%)\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# ------------- Initial user input ------------\n","#@markdown ###Path to training images:\n","Training_source = '' #@param {type:\"string\"}\n","Training_target = '' #@param {type:\"string\"}\n","\n","model_name = '' #@param {type:\"string\"}\n","model_path = '' #@param {type:\"string\"}\n","\n","#@markdown ###Training parameters:\n","#@markdown Number of epochs\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced parameters:\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"integer\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","pooling_steps = 2 #@param [1,2,3,4]{type:\"raw\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","patch_width = 512#@param{type:\"number\"}\n","patch_height = 512#@param{type:\"number\"}\n","min_fraction = 0.02#@param{type:\"number\"}\n","\n","\n","# ------------- Initialising folder, variables and failsafes ------------\n","# Create the folders where to save the model and the QC\n","full_model_path = os.path.join(model_path, model_name)\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Folder already exists and will be overwritten !!'+W)\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," pooling_steps = 2\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n"," patch_width, patch_height = estimatePatchSize(Training_source)\n"," min_fraction = 0.02\n","\n","\n","#The create_patches function will create the two folders below\n","# Patch_source = '/content/img_patches'\n","# Patch_target = '/content/mask_patches'\n","print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n","\n","#Create patches\n","print('Creating patches...')\n","Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction)\n","\n","number_of_training_dataset = len(os.listdir(Patch_source))\n","print('Total number of valid patches: '+str(number_of_training_dataset))\n","\n","if Use_Default_Advanced_Parameters or number_of_steps == 0:\n"," number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n","print('Number of steps: '+str(number_of_steps))\n","\n","# Calculate the number of steps to use for validation\n","validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n","\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = False\n","# Build the default dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = 0.,\n"," height_shift_range = 0.,\n"," rotation_range = 0., #90\n"," zoom_range = 0.,\n"," shear_range = 0.,\n"," horizontal_flip = False,\n"," vertical_flip = False,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","# ------------- Display ------------\n","\n","#if not os.path.exists('/content/img_patches/'):\n","random_choice = random.choice(os.listdir(Patch_source))\n","x = io.imread(os.path.join(Patch_source, 
random_choice))\n","\n","#os.chdir(Training_target)\n","y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest',cmap='gray')\n","plt.title('Training image patch')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest',cmap='gray')\n","plt.title('Training mask patch')\n","plt.axis('off');\n","\n","plt.savefig('/content/TrainingDataExample_Unet2D.png',bbox_inches='tight',pad_inches=0)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n","\n"," The augmentation options below are to be used as follows:\n","\n","* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n","* **zoom_range**: Increasing or decreasing the field of view. E.g. 10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n","* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n","* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n","* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#@markdown ##**Augmentation options**\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," if Use_Default_Augmentation_Parameters:\n"," horizontal_shift = 10 \n"," vertical_shift = 20 \n"," zoom_range = 10\n"," shear_range = 10\n"," horizontal_flip = True\n"," vertical_flip = True\n"," rotation_range = 180\n","#@markdown ###If you are not using the default settings, please provide the values below:\n","\n","#@markdown ###**Image shift, zoom, shear and flip (%)**\n"," else:\n"," horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," vertical_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," shear_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," horizontal_flip = True #@param {type:\"boolean\"}\n"," vertical_flip = True #@param {type:\"boolean\"}\n","\n","#@markdown ###**Rotate image within angle range (degrees):**\n"," rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n","\n","#given behind the # are the default values for each parameter.\n","\n","else:\n"," horizontal_shift = 0 \n"," vertical_shift = 0 \n"," zoom_range = 0\n"," shear_range = 0\n"," horizontal_flip = False\n"," vertical_flip = False\n"," rotation_range = 0\n","\n","\n","# Build the dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n"," height_shift_range = vertical_shift/100.,\n"," rotation_range = rotation_range, #90\n"," zoom_range = zoom_range/100.,\n"," shear_range = 
shear_range/100.,\n"," horizontal_flip = horizontal_flip,\n"," vertical_flip = vertical_flip,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","\n","\n","# ------------- Display ------------\n","dir_augmented_data_imgs=\"/content/augment_img\"\n","dir_augmented_data_masks=\"/content/augment_mask\"\n","random_choice = random.choice(os.listdir(Patch_source))\n","orig_img = load_img(os.path.join(Patch_source,random_choice))\n","orig_mask = load_img(os.path.join(Patch_target,random_choice))\n","\n","augment_view = ImageDataGenerator(**data_gen_args)\n","\n","if Use_Data_augmentation:\n"," print(\"Parameters enabled\")\n"," print(\"Here is what a subset of your augmentations looks like:\")\n"," save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n"," save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n","\n"," fig = plt.figure(figsize=(15, 7))\n"," fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n","\n"," \n"," ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[]) \n"," new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n"," ax.imshow(new_img)\n"," ax.set_title('Original Image')\n"," i = 2\n"," for imgnm in os.listdir(dir_augmented_data_imgs):\n"," ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) \n"," img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n"," ax.imshow(img)\n"," i += 1\n","\n"," ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[]) \n"," new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n"," ax.imshow(new_mask)\n"," ax.set_title('Original Mask')\n"," j=2\n"," for imgnm in os.listdir(dir_augmented_data_masks):\n"," ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) \n"," mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n"," ax.imshow(mask)\n"," j += 1\n"," plt.show()\n","\n","else:\n"," print(\"No augmentation will be used\")\n","\n"," "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
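As a condensed illustration of what the next cell does when a ZeroCostDL4Mic model folder is supplied, the sketch below reads the `learning rate` column of `Quality Control/training_evaluation.csv` and retrieves either the last value or the value at the lowest validation loss. The folder path is a placeholder, and the real cell additionally falls back to `initial_learning_rate` when the file or column is missing.

```python
# Illustrative sketch only - the cell below performs the same lookup with extra fallbacks.
import os
import pandas as pd

pretrained_model_path = '/content/gdrive/MyDrive/my_unet_model'   # placeholder path
csv_path = os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')

if os.path.exists(csv_path):
    history = pd.read_csv(csv_path)
    if 'learning rate' in history.columns:
        last_lr = history['learning rate'].iloc[-1]                           # Weights_choice == "last"
        best_lr = history.loc[history['val_loss'].idxmin(), 'learning rate']  # Weights_choice == "best"
        print('Last learning rate:', last_lr, '| learning rate at best val_loss:', best_lr)
```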
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the UNET_Model_from_\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(R+'WARNING: pretrained model does not exist')\n"," Use_pretrained_model = False\n"," \n","\n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," 
print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(R+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---\n","####**Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time. "]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Play this cell to prepare the model for training\n","\n","\n","# ------------------ Set the generators, model and logger ------------------\n","# This will take the image size and set that as a patch size (arguable...)\n","# Read image size (without actuall reading the data)\n","\n","(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))\n","\n","\n","# This modelcheckpoint will only save the best model from the validation loss point of view\n","model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)\n","\n","print('Getting class weights...')\n","class_weights = getClassWeights(Training_target)\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","else:\n"," h5_file_path = None\n","\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Reduce learning rate on plateau ------------------------\n","\n","reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto',\n"," patience=10, min_lr=0)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Define the model\n","model = unet(pretrained_weights = h5_file_path, \n"," input_size = (patch_width,patch_height,1), \n"," pooling_steps = pooling_steps, \n"," learning_rate = initial_learning_rate, \n"," class_weights = class_weights)\n","\n","config_model= 
model.optimizer.get_config()\n","print(config_model)\n","\n","\n","# ------------------ Failsafes ------------------\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Model folder already existed and has been removed !!'+W)\n"," shutil.rmtree(full_model_path)\n","\n","os.makedirs(full_model_path)\n","os.makedirs(os.path.join(full_model_path,'Quality Control'))\n","\n","\n","# ------------------ Display ------------------\n","print('---------------------------- Main training parameters ----------------------------')\n","print('Number of epochs: '+str(number_of_epochs))\n","print('Batch size: '+str(batch_size))\n","print('Number of training dataset: '+str(number_of_training_dataset))\n","print('Number of training steps: '+str(number_of_steps))\n","print('Number of validation steps: '+str(validation_steps))\n","print('---------------------------- ------------------------ ----------------------------')\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","# history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs=epochs, callbacks=[model_checkpoint,csv_log], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs = number_of_epochs, callbacks=[model_checkpoint, reduce_lr], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","\n","# Save the last model\n","model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
\n","lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n"," \n","\n","\n","# Displaying the time elapsed for training\n","print(\"------------------------------------------\")\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n","print(\"------------------------------------------\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n","if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. 
Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","epochNumber = []\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'),bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n","\n"," The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n","\n","### **Thresholds for image masks**\n","\n"," Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. 
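For reference, the metric itself reduces to the intersection of the two binary masks divided by their union. The toy example below is only an illustration of that definition on a thresholded prediction; the notebook computes the same quantity for every threshold through its own `getIoUvsThreshold` helper.

```python
# Toy IoU computation on two small binary masks (illustration of the definition above).
import numpy as np

ground_truth = np.array([[0, 1, 1],
                         [0, 1, 1],
                         [0, 0, 0]], dtype=bool)
prediction   = np.array([[0, 0, 1],
                         [0, 1, 1],
                         [0, 1, 0]], dtype=bool)   # e.g. a prediction binarised at one threshold

intersection = np.logical_and(ground_truth, prediction).sum()
union        = np.logical_or(ground_truth, prediction).sum()
iou = intersection / union if union > 0 else 0.0
print('IoU =', round(float(iou), 3))   # 3 overlapping pixels / 5 pixels in the union = 0.6
```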
The best threshold for each image and the average of these thresholds will be displayed below. **These values can be a guideline when creating masks for unseen data in section 6.**"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","# ------------- Initialise folders ------------\n","# Create a quality control/Prediction Folder\n","prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n","if os.path.exists(prediction_QC_folder):\n"," shutil.rmtree(prediction_QC_folder)\n","\n","os.makedirs(prediction_QC_folder)\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model\n","unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Source_QC_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","#-----------------------------Calculate Metrics----------------------------------------#\n","\n","f = plt.figure(figsize=((5,5)))\n","\n","with open(os.path.join(full_QC_model_path,'Quality Control', 'QC_metrics_'+QC_model_name+'.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"File name\",\"IoU\", \"IoU-optimised threshold\"]) \n","\n"," # Initialise the lists \n"," filename_list = []\n"," best_threshold_list = []\n"," best_IoU_score_list = []\n","\n"," for filename in os.listdir(Source_QC_folder):\n","\n"," if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n"," print('Running QC on: '+filename)\n"," test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n","\n"," (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(os.path.join(prediction_QC_folder, prediction_prefix+filename), os.path.join(Target_QC_folder, filename))\n"," plt.plot(threshold_list,iou_scores_per_threshold, label=filename)\n","\n"," # Here we find which threshold yielded the highest IoU score for image n.\n"," best_IoU_score = max(iou_scores_per_threshold)\n"," best_threshold = iou_scores_per_threshold.index(best_IoU_score)\n","\n"," # Write the results in the CSV file\n"," writer.writerow([filename, str(best_IoU_score), str(best_threshold)])\n","\n"," # Here we append the best threshold and score to the lists\n"," filename_list.append(filename)\n"," best_IoU_score_list.append(best_IoU_score)\n"," best_threshold_list.append(best_threshold)\n","\n","# Display the IoV vs Threshold plot\n","plt.title('IoU vs. 
Threshold')\n","plt.ylabel('IoU')\n","plt.xlabel('Threshold value')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = filename_list)\n","pdResults[\"IoU\"] = best_IoU_score_list\n","pdResults[\"IoU-optimised threshold\"] = best_threshold_list\n","\n","average_best_threshold = sum(best_threshold_list)/len(best_threshold_list)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Source_QC_folder)):\n"," \n"," plt.figure(figsize=(25,5))\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n"," plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n"," test_prediction_mask = np.empty_like(test_prediction)\n"," test_prediction_mask[test_prediction > average_best_threshold] = 255\n"," test_prediction_mask[test_prediction <= average_best_threshold] = 0\n"," plt.imshow(test_prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_image, cmap='Greens')\n"," plt.imshow(test_prediction_mask, alpha=0.5, cmap='Purples')\n"," metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ' T: ' + str(round(pdResults.loc[file][\"IoU-optimised threshold\"])) + ')'\n"," plt.title(metrics_title)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","print('--------------------------------------------------------------')\n","print('Best average threshold is: '+str(round(average_best_threshold)))\n","print('--------------------------------------------------------------')\n","\n","pdResults.head()\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"6mpdb00q2D7F"},"source":["## **5.3. Export your model into the BioImage Model Zoo format**\n","---\n","This section exports the model into the BioImage Model Zoo format so it can be used directly with DeepImageJ. The new files will be stored in the model folder specified at the beginning of Section 5. \n","\n","Once the cell is executed, you will find a new zip file with the name specified in `Trained_model_name.bioimage.io.model`.\n","\n","To use it with deepImageJ, download it and unzip it in the ImageJ/models/ or Fiji/models/ folder of your local machine. \n","\n","In ImageJ, open the example image given within the downloaded zip file. Go to Plugins > DeepImageJ > DeepImageJ Run. 
Choose this model from the list and click OK.\n","\n"," More information at https://deepimagej.github.io/deepimagej/"]},{"cell_type":"code","metadata":{"cellView":"form","id":"PyS-TGOw2FhU"},"source":["# ------------- User input ------------\n","# information about the model\n","#@markdown ##Introduce the metadata of the model architecture:\n","Trained_model_name = \"\" #@param {type:\"string\"}\n","Trained_model_authors = \"[Author 1, Author 2, Author 3]\" #@param {type:\"string\"}\n","Trained_model_description = \"\" #@param {type:\"string\"}\n","Trained_model_license = 'MIT'#@param {type:\"string\"}\n","Trained_model_references = [\"Falk et al. Nature Methods 2019\", \"Ronneberger et al. arXiv in 2015\", \"Lucas von Chamier et al. biorXiv 2020\"] \n","Trained_model_DOI = [\"https://doi.org/10.1038/s41592-018-0261-2\",\"https://doi.org/10.1007/978-3-319-24574-4_28\", \"https://doi.org/10.1101/2020.03.20.000133\"] \n","\n","#@markdown ##Choose a threshold for DeepImageJ's postprocessing macro:\n","Use_The_Best_Average_Threshold = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","threshold = 85 #@param {type:\"number\"}\n","\n","#@markdown ##Introduce the pixel size (in microns) of the image provided as an example of the model processing:\n","# information about the example image\n","PixelSize = 0.0004 #@param {type:\"number\"}\n","#@markdown ##Do you want to choose the exampleimage?\n","default_example_image = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","fileID = \"\" #@param {type:\"string\"}\n","\n","\n","if Use_The_Best_Average_Threshold:\n"," threshold = average_best_threshold\n","\n","from skimage import io\n","\n","if default_example_image:\n"," source_dir_list = os.listdir(Source_QC_folder)\n"," fileID = os.path.join(Source_QC_folder, source_dir_list[0])\n"," \n","\n","# Read the input image\n","test_img = io.imread(fileID)\n","\n","# Load the model\n","unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","test_prediction = predict_as_tiles(fileID, unet)\n","# test_prediction = io.imread(os.path.join(prediction_QC_folder, prediction_prefix+fileID))\n","# # Binarize it with the threshold chosen\n","test_prediction_mask = convert2Mask(test_prediction, threshold)\n","\n","## Run this cell to export the model to the BioImage Model Zoo format.\n","####\n","\n","from pydeepimagej.yaml import BioImageModelZooConfig\n","import urllib\n","\n","# Check minimum size: it is [8,8] for the 2D XY plane\n","pooling_steps = 0\n","for keras_layer in unet.layers:\n"," if keras_layer.name.startswith('max') or \"pool\" in keras_layer.name:\n"," pooling_steps += 1\n","MinimumSize = [2**(pooling_steps), 2**(pooling_steps)]\n","\n","dij_config = BioImageModelZooConfig(unet, MinimumSize)\n","\n","\n","# Model developer details\n","dij_config.Authors = Trained_model_authors[1:-1].split(',')\n","dij_config.Description = Trained_model_description\n","dij_config.Name = Trained_model_name\n","dij_config.References = Trained_model_references\n","dij_config.DOI = Trained_model_DOI\n","dij_config.License = Trained_model_license\n","\n","# Additional information about the model\n","dij_config.GitHub = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic'\n","dij_config.Date = datetime.now()\n","dij_config.Documentation = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki'\n","dij_config.Tags = ['ZeroCostDL4Mic', 'deepimagej', 'segmentation', 
'TEM', 'unet']\n","dij_config.Framework = 'tensorflow'\n","\n","\n","# Add the information about the test image. Note here PixelSize is given in nm\n","dij_config.add_test_info(test_img, test_prediction_mask, [PixelSize, PixelSize])\n","dij_config.create_covers([test_img, test_prediction_mask])\n","dij_config.Covers = ['./input.png', './output.png']\n","\n","\n","## Prepare preprocessing file\n","min_percentile = 0\n","max_percentile = 99.85\n","\n","path_preprocessing = \"per_sample_scale_range.ijm\"\n","urllib.request.urlretrieve('https://raw.githubusercontent.com/deepimagej/imagej-macros/master/bioimage.io/per_sample_scale_range.ijm', path_preprocessing )\n","\n","# Modify the threshold in the macro to the chosen threshold\n","ijmacro = open(path_preprocessing,\"r\") \n","list_of_lines = ijmacro. readlines()\n","# Line 21 is the one corresponding to the optimal threshold\n","list_of_lines[24] = \"min_percentile = {};\\n\".format(min_percentile)\n","list_of_lines[25] = \"max_percentile = {};\\n\".format(max_percentile)\n","ijmacro = open(path_preprocessing,\"w\") \n","ijmacro. writelines(list_of_lines)\n","ijmacro. close()\n","# Include info about the macros \n","dij_config.Preprocessing = [path_preprocessing]\n","dij_config.Preprocessing_files = [path_preprocessing]\n","# Preprocessing following BioImage Model Zoo specifications\n","dij_config.add_bioimageio_spec('pre-processing', 'scale_range',\n"," mode='per_sample', axes='xyzc',\n"," min_percentile = min_percentile,\n"," max_percentile = max_percentile)\n","\n","## Prepare postprocessing file\n","path_postprocessing = \"8bitBinarize.ijm\"\n","urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/8bitBinarize.ijm\", path_postprocessing )\n","\n","# Modify the threshold in the macro to the chosen threshold\n","ijmacro = open(path_postprocessing,\"r\") \n","list_of_lines = ijmacro. readlines()\n","# Line 21 is the one corresponding to the optimal threshold\n","list_of_lines[21] = \"optimalThreshold = {};\\n\".format(threshold)\n","ijmacro.close()\n","ijmacro = open(path_postprocessing,\"w\") \n","ijmacro. writelines(list_of_lines)\n","ijmacro. 
close()\n","\n","# Include info about the macros \n","dij_config.Postprocessing = [path_postprocessing]\n","dij_config.Postprocessing_files = [path_postprocessing]\n","# Preprocessing following BioImage Model Zoo specifications\n","dij_config.add_bioimageio_spec('post-processing', 'scale_range',\n"," mode='per_sample', axes='xyzc', \n"," min_percentile=0, max_percentile=100)\n","\n","dij_config.add_bioimageio_spec('post-processing', 'scale_linear',\n"," gain=255, offset=0, axes='xy')\n","\n","dij_config.add_bioimageio_spec('post-processing', 'binarize',\n"," threshold=threshold)\n","\n","\n","# Store the model weights\n","# ---------------------------------------\n","# used_bioimageio_model_for_training_URL = \"/Some/URL/bioimage.io/\"\n","# dij_config.Parent = used_bioimageio_model_for_training_URL\n","\n","# Add weights information\n","format_authors = [\"pydeepimagej\"]\n","dij_config.add_weights_formats(unet, 'TensorFlow', \n"," parent=\"keras_hdf5\",\n"," authors=[a for a in format_authors])\n","dij_config.add_weights_formats(unet, 'KerasHDF5', \n"," authors=[a for a in format_authors])\n","\n","## EXPORT THE MODEL\n","deepimagej_model_path = os.path.join(full_QC_model_path, Trained_model_name+'.bioimage.io.model')\n","dij_config.export_model(deepimagej_model_path)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"gz3BxSwu1_XN"},"source":[""]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n"," Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n","\n"," **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. 
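One quick way to spot the single-colour failure mode mentioned above is to inspect the dynamic range of a saved prediction: a collapsed model produces an (almost) constant image. The snippet below is an optional diagnostic that is not part of the notebook, and the file path is only a placeholder for one of your own predicted images.

```python
# Optional diagnostic (not part of the original workflow): check the dynamic range of a prediction.
import numpy as np
from skimage import io

prediction_path = '/content/gdrive/MyDrive/Results/Predicted_example.tif'   # placeholder path
img = io.imread(prediction_path)
print('min:', img.min(), 'max:', img.max(), 'std:', float(np.std(img)))
# A minimum that (nearly) equals the maximum is consistent with the flat, single-colour
# output that triggers the low-contrast warning described above.
```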
Train the network again with different network hyperparameters."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["\n","\n","# ------------- Initial user input ------------\n","#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","Data_folder = '' #@param {type:\"string\"}\n","Results_folder = '' #@param {type:\"string\"}\n","\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","# ------------- Failsafes ------------\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model and prepare generator\n","\n","\n","\n","unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Data_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n"," # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","\n","\n","def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n","\n"," plt.figure(figsize=(18,6))\n"," # Wide-field\n"," plt.subplot(1,3,1)\n"," plt.axis('off')\n"," img_Source = plt.imread(os.path.join(Data_folder, file))\n"," plt.imshow(img_Source, cmap='gray')\n"," plt.title('Source image',fontsize=15)\n"," # Prediction\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))\n"," plt.imshow(img_Prediction, cmap='gray')\n"," plt.title('Prediction',fontsize=15)\n","\n"," # Thresholded mask\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," img_Mask = 
convert2Mask(img_Prediction, threshold)\n"," plt.imshow(img_Mask, cmap='gray')\n"," plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n","\n","\n","interact(show_prediction_mask, continuous_update=False);\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"stS96mFZLMOU"},"source":["## **6.2. Export results as masks**\n","---\n"]},{"cell_type":"code","metadata":{"id":"qb5ZmFstLNbR","cellView":"form"},"source":["\n","# @markdown #Play this cell to save results as masks with the chosen threshold\n","threshold = 120#@param {type:\"number\"}\n","\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n","print('-------------------')\n","print('Masks were saved in: '+Results_folder)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"BphZ0wBrC2Zw"},"source":["# **7. Version log**\n","\n","---\n","**v1.13**: \n","* This version now includes an automatic restart allowing to set the h5py library to v2.10. \n","* The section 1 and 2 are now swapped for better export of *requirements.txt*. \n","* This version also now includes built-in version check and the version log that you're reading now.\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using 2D U-Net!**\n"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/U-Net_3D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb b/Colab_notebooks/Beta notebooks/U-Net_3D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb index 06af3d25..98e36718 100644 --- a/Colab_notebooks/Beta notebooks/U-Net_3D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb +++ b/Colab_notebooks/Beta notebooks/U-Net_3D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb @@ -1,2689 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "U-Net_3D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.7" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "IkSguVy8Xv83" - }, - "source": [ - "# **U-Net (3D)**\n", - " ---\n", - "\n", - " The 3D U-Net was first introduced by [Çiçek et al](https://arxiv.org/abs/1606.06650) for learning dense volumetric segmentations from sparsely annotated ground-truth data building upon the original U-Net architecture by [Ronneberger et al](https://arxiv.org/abs/1505.04597). \n", - "\n", - "**This particular implementation allows supervised learning between any two types of 3D image data. 
If you are interested in image segmentation of 2D datasets, you should use the 2D U-Net notebook instead.**\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project ([ZeroCostDL4Mic](/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki)) jointly developed by the [Jacquemet](https://cellmig.org/) and [Henriques](https://henriqueslab.github.io/) laboratories and created by Daniel Krentzel.\n", - "\n", - "This notebook is laregly based on the following paper: \n", - "\n", - "[**3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation**](https://arxiv.org/pdf/1606.06650.pdf) by Özgün Çiçek *et al.* published on arXiv in 2016\n", - "\n", - "The following two Python libraries play an important role in the notebook: \n", - "\n", - "1. [**Elasticdeform**](/~https://github.com/gvtulder/elasticdeform)\n", - " by Gijs van Tulder was used to augment the 3D training data using elastic grid-based deformations as described in the original 3D U-Net paper. \n", - "\n", - "2. [**Tifffile**](/~https://github.com/cgohlke/tifffile) by Christoph Gohlke is a great library for reading and writing TIFF files. \n", - "\n", - "3. [**Imgaug**](/~https://github.com/aleju/imgaug) by Alexander Jung *et al.* is an amazing library for image augmentation in machine learning - it is the most complete and extensive image augmentation package I have found to date. \n", - "\n", - "The [example dataset](https://www.epfl.ch/labs/cvlab/data/data-em/) represents a 5x5x5µm section taken from the CA1 hippocampus region of the brain with annotated mitochondria and was acquired by Graham Knott and Marco Cantoni at EPFL.\n", - "\n", - "\n", - "**Please also cite the original paper and relevant Python libraries when using or developing this notebook.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use ZeroCostDL4Mic notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cells: \n", - "\n", - "**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new one by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code which can be modfied by selecting the cell. To execute the cell, move your cursor to the `[]`-symbol on the left side of the cell (a play button should appear). Click it to execute the cell. Once the cell is fully executed, the animation stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "Three tabs are located on the upper left side of the notebook:\n", - "\n", - "1. *Table of contents* contains the structure of the notebook. Click the headers to move quickly between sections.\n", - "\n", - "2. *Code snippets* provides a wide array of example code specific to Google Colab. You can ignore this when using this notebook.\n", - "\n", - "3. 
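The Elasticdeform library listed above provides the grid-based elastic deformations used for augmentation later in this notebook (see the `deform_volume` method in Section 2). As a small illustrative sketch on synthetic data (not the notebook's own augmentation cell), a paired deformation of a source volume and its target keeps the two aligned:

```
import numpy as np
import elasticdeform

# Toy 16-slice volume and a matching binary target volume
volume = np.random.rand(16, 128, 128)
target = (np.random.rand(16, 128, 128) > 0.5).astype(float)

# Deform both with the same random grid so source and target stay registered,
# mirroring the sigma/points/order parameters used by deform_volume
[volume_deformed, target_deformed] = elasticdeform.deform_random_grid(
    [volume, target], sigma=5, points=3, order=3)

# Re-binarise the target after interpolation, as the notebook does
target_deformed = target_deformed > 0.1
print(volume_deformed.shape, target_deformed.shape)
```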
*Files* displays the current working directory. We will mount your Google Drive in Section 1.2. so that you can access your files and save them permanently.\n", - "\n", - "**Important:** All uploaded files are purged once the runtime ends.\n", - "\n", - "**Note:** The directory *sample data* in *Files* contains default files. Do not upload anything there!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive by clicking *File* -> *Save a copy in Drive*.\n", - "\n", - "To **edit a cell**, double click on the text. This will either display the source code (in code cells) or the [markdown](https://colab.research.google.com/notebooks/markdown_guide.ipynb#scrollTo=70pYkR9LiOV0) (in text cells).\n", - "You can use `#` in code cells to comment out parts of the code. This allows you to keep the original piece of code while not executing it." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gKDLkLWUd-YX" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - "\n", - "As the network operates in three dimensions, certain consideration should be given to correctly pre-processing the data. Ensure that the structure of interest does not substantially change between slices - image volumes with isotropic pixelsizes are ideal for this architecture.\n", - "\n", - "Each image volume must be provided as an **8-bit** or **binary multipage TIFF file** to maintain the correct ordering of individual image slices. If more than one image volume has been annotated, source and target files must be named identically and placed in separate directories. In case only one image volume has been annotated, source and target file do not have to be placed in separate directories and can be named differently, as long as their paths are explicitly provided in Section 3. \n", - "\n", - "**Prepare two datasets** (*training* and *testing*) for quality control puproses. Make sure that the *testing* dataset does not overlap with the *training* dataset and is ideally sourced from a different acquisiton and sample to ensure robustness of the trained model. \n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "### **Directory structure**\n", - "\n", - "Make sure to adhere to one of the following directory structures. If only one annotated training volume exists, choose the first structure. 
In case more than one training volume is available, choose the second structure.\n", - "\n", - "**Structure 1:** Only one training volume\n", - "```\n", - "path/to/directory/with/one/training/volume\n", - "│--training_source.tif\n", - "│--training_target.tif\n", - "| \n", - "│--testing_source.tif\n", - "|--testing_target.tif \n", - "|\n", - "|--data_to_predict_on.tif\n", - "|--prediction_results.tif\n", - "\n", - "```\n", - "**Structure 2:** Various training volumes\n", - "```\n", - "path/to/directory/with/various/training/volumes\n", - "│--testing_source.tif\n", - "|--testing_target.tif \n", - "|\n", - "└───training\n", - "| └───source\n", - "| | |--training_volume_one.tif\n", - "| | |--training_volume_two.tif\n", - "| | |--...\n", - "| | |--training_volume_n.tif\n", - "| |\n", - "| └───target\n", - "| |--training_volume_one.tif\n", - "| |--training_volume_two.tif\n", - "| |--...\n", - "| |--training_volume_n.tif\n", - "|\n", - "|--data_to_predict_on.tif\n", - "|--prediction_results.tif\n", - "```\n", - "**Note:** Naming directories is completely up to you, as long as the paths are correctly specified throughout the notebook.\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "### **Important note**\n", - "\n", - "* If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do so), you will need to run **Sections 1 - 4**, then use **Section 5** to assess the quality of your model and **Section 6** to run predictions using the model that you trained.\n", - "\n", - "* If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 5** to assess the quality of your model.\n", - "\n", - "* If you only wish to **Run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "M-GZMaL7pd8a" - }, - "source": [ - "#@markdown ##**Download example dataset**\n", - "\n", - "#@markdown This usually takes a few minutes. 
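Before moving on, it can save time to confirm that the data matches the requirements above: identically named multipage TIFFs, matching dimensions and 8-bit (or binary) targets. A minimal sketch, assuming the Structure 2 layout with hypothetical folder names:

```
import os
import tifffile

source_dir = 'training/source'   # hypothetical paths following Structure 2
target_dir = 'training/target'

source_files = sorted(os.listdir(source_dir))
target_files = sorted(os.listdir(target_dir))
assert source_files == target_files, 'Source and target file names must match'

for name in source_files:
    src = tifffile.imread(os.path.join(source_dir, name))
    tgt = tifffile.imread(os.path.join(target_dir, name))
    # Each pair must have identical (z, y, x) dimensions
    assert src.shape == tgt.shape, name + ': source and target shapes differ'
    print(name, src.shape, src.dtype, tgt.dtype)
```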
The images are saved in *example_dataset*.\n", - "\n", - "import requests \n", - "import os\n", - "from tqdm.notebook import tqdm \n", - "\n", - "def make_directory(dir):\n", - " if not os.path.exists(dir):\n", - " os.makedirs(dir)\n", - "\n", - "def download_from_url(url, save_as):\n", - " file_url = url\n", - " r = requests.get(file_url, stream=True) \n", - " \n", - " with open(save_as, 'wb') as file: \n", - " for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=126875, ncols=1000):\n", - " if block:\n", - " file.write(block) \n", - "\n", - "\n", - "make_directory('example_dataset')\n", - "\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training.tif', 'example_dataset/training.tif')\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training_groundtruth.tif', 'example_dataset/training_groundtruth.tif')\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing.tif', 'example_dataset/testing.tif')\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing_groundtruth.tif', 'example_dataset/testing_groundtruth.tif')\n", - "\n", - "print('Example dataset successfully downloaded!')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "\n", - "\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zCvebubeSaGY", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "%tensorflow_version 1.x\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sNIVx8_CLolt" - }, - "source": [ - "## **1.2. Mount Google Drive**\n", - "---\n", - " To use this notebook with your **own data**, place it in a folder on **Google Drive** following one of the directory structures outlined in **Section 0**.\n", - "\n", - "1. **Run** the **cell** below to mount your Google Drive and follow the link. \n", - "\n", - "2. **Sign in** to your Google account and press 'Allow'. \n", - "\n", - "3. Next, copy the **authorization code**, paste it into the cell and press enter. 
This will allow Colab to read and write data from and to your Google Drive. \n", - "\n", - "4. Once this is done, your data can be viewed in the **Files tab** on the top left of the notebook after hitting 'Refresh'." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "01Djr8v-5pPk", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to connect your Google Drive to Colab\n", - "\n", - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "zxELU7CIp4oF" - }, - "source": [ - "#@markdown ##Unzip pre-trained model directory\n", - "\n", - "#@markdown 1. Upload a zipped model directory using the *Files* tab\n", - "#@markdown 2. Run this cell to unzip your model file\n", - "#@markdown 3. The model directory will appear in the *Files* tab \n", - "\n", - "from google.colab import files\n", - "\n", - "zipped_model_file = \"\" #@param {type:\"string\"}\n", - "\n", - "!unzip \"$zipped_model_file\"" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AdN8B91xZO0x" - }, - "source": [ - "# **2. Install 3D U-Net dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mK-HtWvmaj_e" - }, - "source": [ - "## 2.1. Install key dependencies\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "IE8kgtUOakRh", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play to install 3D U-Net dependencies\n", - "\n", - "!pip install pydeepimagej==2.1.2\n", - "# !pip uninstall -y keras-nightly\n", - "!pip install data\n", - "!pip install fpdf\n", - "!pip install h5py==2.10\n", - "\n", - "\n", - "#Force session restart\n", - "# exit(0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CHMcIOZjatrI" - }, - "source": [ - "## 2.2. Restart your runtime and run all the cells again. \n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b3SS2C5eeffQ" - }, - "source": [ - "** Skip this step if you already restarted the runtime.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kvoAFGpba9P1" - }, - "source": [ - "## 2.3. 
Load key dependencies\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "fq21zJVFNASx", - "cellView": "form" - }, - "source": [ - "#@markdown ##Install dependencies and instantiate network\n", - "Notebook_version = ['1.12']\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory \n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "#Put the imported code and libraries here\n", - "from __future__ import absolute_import, division, print_function, unicode_literals\n", - "\n", - "try:\n", - " import elasticdeform\n", - "except:\n", - " !pip install elasticdeform\n", - " import elasticdeform\n", - "\n", - "try:\n", - " import tifffile\n", - "except:\n", - " !pip install tifffile\n", - " import tifffile\n", - "\n", - "try:\n", - " import imgaug.augmenters as iaa\n", - "except:\n", - " !pip install imgaug\n", - " import imgaug.augmenters as iaa\n", - "\n", - "import os\n", - "import csv\n", - "import random\n", - "import h5py\n", - "import imageio\n", - "import math\n", - "import shutil\n", - "\n", - "import pandas as pd\n", - "from glob import glob\n", - "from tqdm import tqdm\n", - "\n", - "from skimage import transform\n", - "from skimage import exposure\n", - "from skimage import color\n", - "from skimage import io\n", - "\n", - "from scipy.ndimage import zoom\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import numpy as np\n", - "import tensorflow as tf\n", - "\n", - "from keras import backend as K\n", - "\n", - "from keras.layers import Conv3D\n", - "from keras.layers import BatchNormalization\n", - "from keras.layers import ReLU\n", - "from keras.layers import MaxPooling3D\n", - "from keras.layers import Conv3DTranspose\n", - "from keras.layers import Input\n", - "from keras.layers import Concatenate\n", - "\n", - "from keras.optimizers import Adam, SGD, RMSprop\n", - "from keras.models import Model\n", - "\n", - "# from keras.utils import Sequence\n", - "from tensorflow.keras.utils import Sequence\n", - "\n", - "from keras.callbacks import ModelCheckpoint\n", - "from 
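Because Section 2.1 pins `h5py` to version 2.10 and requires a runtime restart before this cell, a quick sanity check (a minimal sketch, not part of the notebook itself) can confirm that the restart took effect and that the TensorFlow 1.x runtime is active:

```
import h5py
import tensorflow as tf

# After the restart, h5py should report 2.10.x and TensorFlow should be a 1.x release
print('h5py version:', h5py.__version__)
print('TensorFlow version:', tf.__version__)
```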
keras.callbacks import CSVLogger\n", - "from keras.callbacks import Callback\n", - "\n", - "from keras.metrics import RootMeanSquaredError\n", - "\n", - "from ipywidgets import interact\n", - "from ipywidgets import interactive\n", - "from ipywidgets import fixed\n", - "from ipywidgets import interact_manual \n", - "import ipywidgets as widgets\n", - "\n", - "from fpdf import FPDF, HTMLMixin\n", - "from datetime import datetime\n", - "import subprocess\n", - "from pip._internal.operations.freeze import freeze\n", - "import time\n", - "\n", - "from skimage import io\n", - "import matplotlib\n", - "\n", - "print(\"Dependencies installed and imported.\")\n", - "\n", - "# Define MultiPageTiffGenerator class\n", - "class MultiPageTiffGenerator(Sequence):\n", - "\n", - " def __init__(self,\n", - " source_path,\n", - " target_path,\n", - " batch_size=1,\n", - " shape=(128,128,32,1),\n", - " augment=False,\n", - " augmentations=[],\n", - " deform_augment=False,\n", - " deform_augmentation_params=(5,3,4),\n", - " val_split=0.2,\n", - " is_val=False,\n", - " random_crop=True,\n", - " downscale=1,\n", - " binary_target=False):\n", - "\n", - " # If directory with various multi-page tiffiles is provided read as list\n", - " if os.path.isfile(source_path):\n", - " self.dir_flag = False\n", - " self.source = tifffile.imread(source_path)\n", - " if binary_target:\n", - " self.target = tifffile.imread(target_path).astype(np.bool)\n", - " else:\n", - " self.target = tifffile.imread(target_path)\n", - "\n", - " elif os.path.isdir(source_path):\n", - " self.dir_flag = True\n", - " self.source_dir_list = glob(os.path.join(source_path, '*'))\n", - " self.target_dir_list = glob(os.path.join(target_path, '*'))\n", - "\n", - " self.source_dir_list.sort()\n", - " self.target_dir_list.sort()\n", - "\n", - " self.shape = shape\n", - " self.batch_size = batch_size\n", - " self.augment = augment\n", - " self.val_split = val_split\n", - " self.is_val = is_val\n", - " self.random_crop = random_crop\n", - " self.downscale = downscale\n", - " self.binary_target = binary_target\n", - " self.deform_augment = deform_augment\n", - " self.on_epoch_end()\n", - " \n", - " if self.augment:\n", - " # pass list of augmentation functions \n", - " self.seq = iaa.Sequential(augmentations, random_order=True) # apply augmenters in random order\n", - " if self.deform_augment:\n", - " self.deform_sigma, self.deform_points, self.deform_order = deform_augmentation_params\n", - "\n", - " def __len__(self):\n", - " # If various multi-page tiff files provided sum all images within each\n", - " if self.augment:\n", - " augment_factor = 4\n", - " else:\n", - " augment_factor = 1\n", - " \n", - " if self.dir_flag:\n", - " num_of_imgs = 0\n", - " for tiff_path in self.source_dir_list:\n", - " num_of_imgs += tifffile.imread(tiff_path).shape[0]\n", - " xy_shape = tifffile.imread(self.source_dir_list[0]).shape[1:]\n", - "\n", - " if self.is_val:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = xy_shape[0] * xy_shape[1] * self.val_split * num_of_imgs\n", - " return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - " else:\n", - " return math.floor(self.val_split * num_of_imgs / self.batch_size)\n", - " else:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = xy_shape[0] * xy_shape[1] * (1 - self.val_split) * num_of_imgs\n", - " return math.floor(augment_factor * volume / (crop_volume 
* self.batch_size * self.downscale))\n", - "\n", - " else:\n", - " return math.floor(augment_factor*(1 - self.val_split) * num_of_imgs/self.batch_size)\n", - " else:\n", - " if self.is_val:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = self.source.shape[0] * self.source.shape[1] * self.val_split * self.source.shape[2]\n", - " return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - " else:\n", - " return math.floor((self.val_split * self.source.shape[0] / self.batch_size))\n", - " else:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = self.source.shape[0] * self.source.shape[1] * (1 - self.val_split) * self.source.shape[2]\n", - " return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - " else:\n", - " return math.floor(augment_factor * (1 - self.val_split) * self.source.shape[0] / self.batch_size)\n", - "\n", - " def __getitem__(self, idx):\n", - " source_batch = np.empty((self.batch_size,\n", - " self.shape[0],\n", - " self.shape[1],\n", - " self.shape[2],\n", - " self.shape[3]))\n", - " target_batch = np.empty((self.batch_size,\n", - " self.shape[0],\n", - " self.shape[1],\n", - " self.shape[2],\n", - " self.shape[3]))\n", - "\n", - " for batch in range(self.batch_size):\n", - " # Modulo operator ensures IndexError is avoided\n", - " stack_start = self.batch_list[(idx+batch*self.shape[2])%len(self.batch_list)]\n", - "\n", - " if self.dir_flag:\n", - " self.source = tifffile.imread(self.source_dir_list[stack_start[0]])\n", - " if self.binary_target:\n", - " self.target = tifffile.imread(self.target_dir_list[stack_start[0]]).astype(np.bool)\n", - " else:\n", - " self.target = tifffile.imread(self.target_dir_list[stack_start[0]])\n", - "\n", - " src_list = []\n", - " tgt_list = []\n", - " for i in range(stack_start[1], stack_start[1]+self.shape[2]):\n", - " src = self.source[i]\n", - " src = transform.downscale_local_mean(src, (self.downscale, self.downscale))\n", - " if not self.random_crop:\n", - " src = transform.resize(src, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n", - " src = self._min_max_scaling(src)\n", - " src_list.append(src)\n", - "\n", - " tgt = self.target[i]\n", - " tgt = transform.downscale_local_mean(tgt, (self.downscale, self.downscale))\n", - " if not self.random_crop:\n", - " tgt = transform.resize(tgt, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n", - " if not self.binary_target:\n", - " tgt = self._min_max_scaling(tgt)\n", - " tgt_list.append(tgt)\n", - "\n", - " if self.random_crop:\n", - " if src.shape[0] == self.shape[0]:\n", - " x_rand = 0\n", - " if src.shape[1] == self.shape[1]:\n", - " y_rand = 0\n", - " if src.shape[0] > self.shape[0]:\n", - " x_rand = np.random.randint(src.shape[0] - self.shape[0])\n", - " if src.shape[1] > self.shape[1]:\n", - " y_rand = np.random.randint(src.shape[1] - self.shape[1])\n", - " if src.shape[0] < self.shape[0] or src.shape[1] < self.shape[1]:\n", - " raise ValueError('Patch shape larger than (downscaled) source shape')\n", - " \n", - " for i in range(self.shape[2]):\n", - " if self.random_crop:\n", - " src = src_list[i]\n", - " tgt = tgt_list[i]\n", - " src_crop = src[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n", - " tgt_crop = tgt[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n", - " else:\n", - " src_crop = src_list[i]\n", - " 
tgt_crop = tgt_list[i]\n", - "\n", - " source_batch[batch,:,:,i,0] = src_crop\n", - " target_batch[batch,:,:,i,0] = tgt_crop\n", - "\n", - " if self.augment:\n", - " # On-the-fly data augmentation\n", - " source_batch, target_batch = self.augment_volume(source_batch, target_batch)\n", - "\n", - " # Data augmentation by reversing stack\n", - " if np.random.random() > 0.5:\n", - " source_batch, target_batch = source_batch[::-1], target_batch[::-1]\n", - " \n", - " # Data augmentation by elastic deformation\n", - " if np.random.random() > 0.5 and self.deform_augment:\n", - " source_batch, target_batch = self.deform_volume(source_batch, target_batch)\n", - " \n", - " if not self.binary_target:\n", - " target_batch = self._min_max_scaling(target_batch)\n", - " \n", - " return self._min_max_scaling(source_batch), target_batch\n", - " \n", - " else:\n", - " return source_batch, target_batch\n", - "\n", - " def on_epoch_end(self):\n", - " # Validation split performed here\n", - " self.batch_list = []\n", - " # Create batch_list of all combinations of tifffile and stack position\n", - " if self.dir_flag:\n", - " for i in range(len(self.source_dir_list)):\n", - " num_of_pages = tifffile.imread(self.source_dir_list[i]).shape[0]\n", - " if self.is_val:\n", - " start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n", - " for j in range(start_page, num_of_pages-self.shape[2]):\n", - " self.batch_list.append([i, j])\n", - " else:\n", - " last_page = math.floor((1-self.val_split)*num_of_pages)\n", - " for j in range(last_page-self.shape[2]):\n", - " self.batch_list.append([i, j])\n", - " else:\n", - " num_of_pages = self.source.shape[0]\n", - " if self.is_val:\n", - " start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n", - " for j in range(start_page, num_of_pages-self.shape[2]):\n", - " self.batch_list.append([0, j])\n", - "\n", - " else:\n", - " last_page = math.floor((1-self.val_split)*num_of_pages)\n", - " for j in range(last_page-self.shape[2]):\n", - " self.batch_list.append([0, j])\n", - " \n", - " if self.is_val and (len(self.batch_list) <= 0):\n", - " raise ValueError('validation_split too small! 
Increase val_split or decrease z-depth')\n", - " random.shuffle(self.batch_list)\n", - " \n", - " def _min_max_scaling(self, data):\n", - " n = data - np.min(data)\n", - " d = np.max(data) - np.min(data) \n", - " \n", - " return n/d\n", - " \n", - " def class_weights(self):\n", - " ones = 0\n", - " pixels = 0\n", - "\n", - " if self.dir_flag:\n", - " for i in range(len(self.target_dir_list)):\n", - " tgt = tifffile.imread(self.target_dir_list[i]).astype(np.bool)\n", - " ones += np.sum(tgt)\n", - " pixels += tgt.shape[0]*tgt.shape[1]*tgt.shape[2]\n", - " else:\n", - " ones = np.sum(self.target)\n", - " pixels = self.target.shape[0]*self.target.shape[1]*self.target.shape[2]\n", - " p_ones = ones/pixels\n", - " p_zeros = 1-p_ones\n", - "\n", - " # Return swapped probability to increase weight of unlikely class\n", - " return p_ones, p_zeros\n", - "\n", - " def deform_volume(self, src_vol, tgt_vol):\n", - " [src_dfrm, tgt_dfrm] = elasticdeform.deform_random_grid([src_vol, tgt_vol],\n", - " axis=(1, 2, 3),\n", - " sigma=self.deform_sigma,\n", - " points=self.deform_points,\n", - " order=self.deform_order)\n", - " if self.binary_target:\n", - " tgt_dfrm = tgt_dfrm > 0.1\n", - " \n", - " return self._min_max_scaling(src_dfrm), tgt_dfrm \n", - "\n", - " def augment_volume(self, src_vol, tgt_vol):\n", - " src_vol_aug = np.empty(src_vol.shape)\n", - " tgt_vol_aug = np.empty(tgt_vol.shape)\n", - "\n", - " for i in range(src_vol.shape[3]):\n", - " src_vol_aug[:,:,:,i,0], tgt_vol_aug[:,:,:,i,0] = self.seq(images=src_vol[:,:,:,i,0].astype('float16'), \n", - " segmentation_maps=tgt_vol[:,:,:,i,0].astype(bool))\n", - " return self._min_max_scaling(src_vol_aug), tgt_vol_aug\n", - "\n", - " def sample_augmentation(self, idx):\n", - " src, tgt = self.__getitem__(idx)\n", - "\n", - " src_aug, tgt_aug = self.augment_volume(src, tgt)\n", - " \n", - " if self.deform_augment:\n", - " src_aug, tgt_aug = self.deform_volume(src_aug, tgt_aug)\n", - "\n", - " return src_aug, tgt_aug \n", - "\n", - "# Define custom loss and dice coefficient\n", - "def dice_coefficient(y_true, y_pred):\n", - " eps = 1e-6\n", - " y_true_f = K.flatten(y_true)\n", - " y_pred_f = K.flatten(y_pred)\n", - " intersection = K.sum(y_true_f*y_pred_f)\n", - "\n", - " return (2.*intersection)/(K.sum(y_true_f*y_true_f)+K.sum(y_pred_f*y_pred_f)+eps)\n", - "\n", - "def weighted_binary_crossentropy(zero_weight, one_weight):\n", - " def _weighted_binary_crossentropy(y_true, y_pred):\n", - " binary_crossentropy = K.binary_crossentropy(y_true, y_pred)\n", - "\n", - " weight_vector = y_true*one_weight+(1.-y_true)*zero_weight\n", - " weighted_binary_crossentropy = weight_vector*binary_crossentropy\n", - "\n", - " return K.mean(weighted_binary_crossentropy)\n", - "\n", - " return _weighted_binary_crossentropy\n", - "\n", - "# Custom callback showing sample prediction\n", - "class SampleImageCallback(Callback):\n", - "\n", - " def __init__(self, model, sample_data, model_path, save=False):\n", - " self.model = model\n", - " self.sample_data = sample_data\n", - " self.model_path = model_path\n", - " self.save = save\n", - "\n", - " def on_epoch_end(self, epoch, logs={}):\n", - " sample_predict = self.model.predict_on_batch(self.sample_data)\n", - "\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(self.sample_data[0,:,:,0,0], interpolation='nearest', cmap='gray')\n", - " plt.title('Sample source')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(sample_predict[0,:,:,0,0], interpolation='nearest', 
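The Dice coefficient and the class weighting returned by `class_weights` can be sanity-checked on a toy pair of masks. A small NumPy sketch mirroring the definitions above (illustrative numbers only):

```
import numpy as np

y_true = np.array([1, 1, 0, 0, 0, 0, 0, 0], dtype=float)
y_pred = np.array([1, 0, 0, 0, 0, 0, 0, 0], dtype=float)

# Dice = 2*|intersection| / (|y_true|^2 + |y_pred|^2), as in dice_coefficient above
dice = 2 * np.sum(y_true * y_pred) / (np.sum(y_true ** 2) + np.sum(y_pred ** 2) + 1e-6)
print('Dice:', dice)   # 2*1 / (2 + 1) = 0.67

# class_weights() returns the swapped class probabilities, so the rarer
# foreground class receives the larger weight in the weighted cross-entropy
p_ones = y_true.mean()                    # fraction of foreground voxels (0.25 here)
zero_weight, one_weight = p_ones, 1 - p_ones
print('zero_weight:', zero_weight, 'one_weight:', one_weight)
```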
cmap='magma')\n", - " plt.title('Predicted target')\n", - " plt.axis('off');\n", - "\n", - " plt.show()\n", - "\n", - " if self.save:\n", - " plt.savefig(self.model_path + '/epoch_' + str(epoch+1) + '.png')\n", - "\n", - "\n", - "# Define Unet3D class\n", - "class Unet3D:\n", - "\n", - " def __init__(self,\n", - " shape=(256,256,16,1)):\n", - " if isinstance(shape, str):\n", - " shape = eval(shape)\n", - "\n", - " self.shape = shape\n", - " \n", - " input_tensor = Input(self.shape, name='input')\n", - "\n", - " self.model = self.unet_3D(input_tensor)\n", - "\n", - " def down_block_3D(self, input_tensor, filters):\n", - " x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(input_tensor)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " return x\n", - "\n", - " def up_block_3D(self, input_tensor, concat_layer, filters):\n", - " x = Conv3DTranspose(filters, kernel_size=(2,2,2), strides=(2,2,2))(input_tensor)\n", - "\n", - " x = Concatenate()([x, concat_layer])\n", - "\n", - " x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(x)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " return x\n", - "\n", - " def unet_3D(self, input_tensor, filters=32):\n", - " d1 = self.down_block_3D(input_tensor, filters=filters)\n", - " p1 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d1)\n", - " d2 = self.down_block_3D(p1, filters=filters*2)\n", - " p2 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d2)\n", - " d3 = self.down_block_3D(p2, filters=filters*4)\n", - " p3 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d3)\n", - "\n", - " d4 = self.down_block_3D(p3, filters=filters*8)\n", - "\n", - " u1 = self.up_block_3D(d4, d3, filters=filters*4)\n", - " u2 = self.up_block_3D(u1, d2, filters=filters*2)\n", - " u3 = self.up_block_3D(u2, d1, filters=filters)\n", - "\n", - " output_tensor = Conv3D(filters=1, kernel_size=(1,1,1), activation='sigmoid')(u3)\n", - "\n", - " return Model(inputs=[input_tensor], outputs=[output_tensor])\n", - "\n", - " def summary(self):\n", - " return self.model.summary()\n", - "\n", - " # Pass generators instead\n", - " def train(self, \n", - " epochs, \n", - " batch_size, \n", - " train_generator,\n", - " val_generator, \n", - " model_path, \n", - " model_name,\n", - " optimizer='adam',\n", - " learning_rate=0.001,\n", - " loss='weighted_binary_crossentropy',\n", - " metrics='dice',\n", - " ckpt_period=1, \n", - " save_best_ckpt_only=False, \n", - " ckpt_path=None):\n", - "\n", - " class_weight_zero, class_weight_one = train_generator.class_weights()\n", - " \n", - " if loss == 'weighted_binary_crossentropy':\n", - " loss = weighted_binary_crossentropy(class_weight_zero, class_weight_one)\n", - " \n", - " if metrics == 'dice':\n", - " metrics = dice_coefficient\n", - " \n", - " if optimizer == 'adam':\n", - " optimizer = Adam(lr=learning_rate)\n", - " elif optimizer == 'sgd':\n", - " optimizer = SGD(lr=learning_rate)\n", - " elif optimizer == 'rmsprop':\n", - " optimizer = RMSprop(lr=learning_rate)\n", - "\n", - " self.model.compile(optimizer=optimizer,\n", - " loss=loss,\n", - " metrics=[metrics])\n", - "\n", - " if ckpt_path is not 
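Because the encoder applies three 2x2x2 max-pooling steps, each patch dimension should be divisible by 8 for the skip connections to concatenate cleanly. A minimal sketch (assuming the cell defining `Unet3D` above has been run) that builds a small model and checks that the sigmoid output matches the input patch shape:

```
# Assumes the Unet3D class defined above is available in the session
unet3d = Unet3D(shape=(64, 64, 8, 1))
print(unet3d.model.input_shape)    # (None, 64, 64, 8, 1)
print(unet3d.model.output_shape)   # (None, 64, 64, 8, 1)
```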
None:\n", - " self.model.load_weights(ckpt_path)\n", - "\n", - " full_model_path = os.path.join(model_path, model_name)\n", - "\n", - " if not os.path.exists(full_model_path):\n", - " os.makedirs(full_model_path)\n", - " \n", - " log_dir = full_model_path + '/Quality Control'\n", - "\n", - " if not os.path.exists(log_dir):\n", - " os.makedirs(log_dir)\n", - " \n", - " ckpt_dir = full_model_path + '/ckpt'\n", - "\n", - " if not os.path.exists(ckpt_dir):\n", - " os.makedirs(ckpt_dir)\n", - "\n", - " csv_out_name = log_dir + '/training_evaluation.csv'\n", - " if ckpt_path is None:\n", - " csv_logger = CSVLogger(csv_out_name)\n", - " else:\n", - " csv_logger = CSVLogger(csv_out_name, append=True)\n", - "\n", - " if save_best_ckpt_only:\n", - " ckpt_name = ckpt_dir + '/' + model_name + '.hdf5'\n", - " else:\n", - " ckpt_name = ckpt_dir + '/' + model_name + '_epoch_{epoch:02d}_val_loss_{val_loss:.4f}.hdf5'\n", - " \n", - " model_ckpt = ModelCheckpoint(ckpt_name,\n", - " verbose=1,\n", - " period=ckpt_period,\n", - " save_best_only=save_best_ckpt_only,\n", - " save_weights_only=True)\n", - "\n", - " sample_batch, __ = val_generator.__getitem__(random.randint(0, len(val_generator)))\n", - " sample_img = SampleImageCallback(self.model, \n", - " sample_batch, \n", - " model_path)\n", - "\n", - " self.model.fit_generator(generator=train_generator,\n", - " validation_data=val_generator,\n", - " validation_steps=math.floor(len(val_generator)/batch_size),\n", - " epochs=epochs,\n", - " callbacks=[csv_logger,\n", - " model_ckpt,\n", - " sample_img])\n", - "\n", - " last_ckpt_name = ckpt_dir + '/' + model_name + '_last.hdf5'\n", - " self.model.save_weights(last_ckpt_name)\n", - "\n", - " def _min_max_scaling(self, data):\n", - " n = data - np.min(data)\n", - " d = np.max(data) - np.min(data) \n", - " \n", - " return n/d\n", - "\n", - " def predict(self, \n", - " input, \n", - " ckpt_path, \n", - " z_range=None, \n", - " downscaling=None, \n", - " true_patch_size=None):\n", - "\n", - " self.model.load_weights(ckpt_path)\n", - "\n", - " if isinstance(downscaling, str):\n", - " downscaling = eval(downscaling)\n", - "\n", - " if math.isnan(downscaling):\n", - " downscaling = None\n", - "\n", - " if isinstance(true_patch_size, str):\n", - " true_patch_size = eval(true_patch_size)\n", - " \n", - " if not isinstance(true_patch_size, tuple): \n", - " if math.isnan(true_patch_size):\n", - " true_patch_size = None\n", - "\n", - " if isinstance(input, str):\n", - " src_volume = tifffile.imread(input)\n", - " elif isinstance(input, np.ndarray):\n", - " src_volume = input\n", - " else:\n", - " raise TypeError('Input is not path or numpy array!')\n", - " \n", - " in_size = src_volume.shape\n", - "\n", - " if downscaling or true_patch_size is not None:\n", - " x_scaling = 0\n", - " y_scaling = 0\n", - "\n", - " if true_patch_size is not None:\n", - " x_scaling += true_patch_size[0]/self.shape[0]\n", - " y_scaling += true_patch_size[1]/self.shape[1]\n", - " if downscaling is not None:\n", - " x_scaling += downscaling\n", - " y_scaling += downscaling\n", - "\n", - " src_list = []\n", - " for i in range(src_volume.shape[0]):\n", - " src_list.append(transform.downscale_local_mean(src_volume[i], (int(x_scaling), int(y_scaling))))\n", - " src_volume = np.array(src_list) \n", - "\n", - " if z_range is not None:\n", - " src_volume = src_volume[z_range[0]:z_range[1]]\n", - "\n", - " src_volume = self._min_max_scaling(src_volume) \n", - "\n", - " src_array = np.zeros((1,\n", - " 
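The `train` method above is invoked by the notebook in Section 4. Purely as an illustrative sketch of how its arguments fit together (hypothetical file names and model path, assuming the classes defined above are available), a call could look like this:

```
# Hypothetical Structure 1 data: one training volume and its annotated target
train_generator = MultiPageTiffGenerator('training_source.tif', 'training_target.tif',
                                         batch_size=1, shape=(256, 256, 8, 1),
                                         val_split=0.2, is_val=False,
                                         random_crop=True, binary_target=True)
val_generator = MultiPageTiffGenerator('training_source.tif', 'training_target.tif',
                                       batch_size=1, shape=(256, 256, 8, 1),
                                       val_split=0.2, is_val=True,
                                       random_crop=True, binary_target=True)

model = Unet3D(shape=(256, 256, 8, 1))
model.train(epochs=200,
            batch_size=1,
            train_generator=train_generator,
            val_generator=val_generator,
            model_path='/content/gdrive/My Drive/models',   # hypothetical location
            model_name='my_unet3d',                          # hypothetical name
            loss='weighted_binary_crossentropy',
            metrics='dice',
            optimizer='adam',
            learning_rate=0.001,
            ckpt_period=1,
            save_best_ckpt_only=True)
```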
math.ceil(src_volume.shape[1]/self.shape[0])*self.shape[0], \n", - " math.ceil(src_volume.shape[2]/self.shape[1])*self.shape[1],\n", - " math.ceil(src_volume.shape[0]/self.shape[2])*self.shape[2], \n", - " self.shape[3]))\n", - "\n", - " for i in range(src_volume.shape[0]):\n", - " src_array[0,:src_volume.shape[1],:src_volume.shape[2],i,0] = src_volume[i]\n", - "\n", - " pred_array = np.empty(src_array.shape)\n", - "\n", - " for i in range(math.ceil(src_volume.shape[1]/self.shape[0])):\n", - " for j in range(math.ceil(src_volume.shape[2]/self.shape[1])):\n", - " for k in range(math.ceil(src_volume.shape[0]/self.shape[2])):\n", - " pred_temp = self.model.predict(src_array[:,\n", - " i*self.shape[0]:i*self.shape[0]+self.shape[0],\n", - " j*self.shape[1]:j*self.shape[1]+self.shape[1],\n", - " k*self.shape[2]:k*self.shape[2]+self.shape[2]])\n", - " pred_array[:,\n", - " i*self.shape[0]:i*self.shape[0]+self.shape[0],\n", - " j*self.shape[1]:j*self.shape[1]+self.shape[1],\n", - " k*self.shape[2]:k*self.shape[2]+self.shape[2]] = pred_temp\n", - " \n", - " pred_volume = np.rollaxis(np.squeeze(pred_array), -1)[:src_volume.shape[0],:src_volume.shape[1],:src_volume.shape[2]] \n", - "\n", - " if downscaling is not None:\n", - " pred_list = []\n", - " for i in range(pred_volume.shape[0]):\n", - " pred_list.append(transform.resize(pred_volume[i], (in_size[1], in_size[2]), preserve_range=True))\n", - " pred_volume = np.array(pred_list)\n", - "\n", - " return pred_volume\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'U-Net 3D'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and methods:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " if 
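The `predict` method above pads the volume up to the next multiple of the patch size in every dimension, predicts each patch and stitches the results back together before cropping to the original size. A small worked sketch of that bookkeeping with hypothetical dimensions:

```
import math

# Hypothetical volume: 100 slices of 700 x 900 pixels, predicted with (256, 256, 8) patches
vol_z, vol_y, vol_x = 100, 700, 900
patch_y, patch_x, patch_z = 256, 256, 8

tiles_y = math.ceil(vol_y / patch_y)   # 3
tiles_x = math.ceil(vol_x / patch_x)   # 4
tiles_z = math.ceil(vol_z / patch_z)   # 13

print('Padded shape:', (tiles_y * patch_y, tiles_x * patch_x, tiles_z * patch_z))  # (768, 1024, 104)
print('Patches predicted:', tiles_y * tiles_x * tiles_z)                           # 156
```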
os.path.isdir(training_source):\n", - " shape = io.imread(training_source+'/'+os.listdir(training_source)[0]).shape\n", - " elif os.path.isfile(training_source):\n", - " shape = io.imread(training_source).shape\n", - " else:\n", - " print('Cannot read training data.')\n", - "\n", - " dataset_size = len(train_generator)\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch_size: '+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented by'\n", - " if add_gaussian_blur == True:\n", - " aug_text = aug_text+'\\n- gaussian blur'\n", - " if add_linear_contrast == True:\n", - " aug_text = aug_text+'\\n- linear contrast'\n", - " if add_additive_gaussian_noise == True:\n", - " aug_text = aug_text+'\\n- additive gaussian noise'\n", - " if augmenters != '':\n", - " aug_text = aug_text+'\\n- imgaug augmentations: '+augmenters\n", - " if add_elastic_deform == True:\n", - " aug_text = aug_text+'\\n- elastic deformation'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if use_default_advanced_parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ParameterValue
number_of_epochs{0}
batch_size{1}
patch_size{2}
image_pre_processing{3}
validation_split_in_percent{4}
downscaling_in_xy{5}
binary_target{6}
loss_function{7}
metrics{8}
optimizer{9}
learning_rate{10}
checkpointing_period{11}
save_best_only{12}
resume_training{13}
\n", - " \"\"\".format(number_of_epochs,batch_size,str(patch_size[0])+'x'+str(patch_size[1])+'x'+str(patch_size[2]),image_pre_processing, validation_split_in_percent, downscaling_in_xy, str(binary_target), loss_function, metrics, optimizer, learning_rate, checkpointing_period, str(save_best_only), str(resume_training))\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_Unet3D.png').shape\n", - " pdf.image('/content/TrainingDataExample_Unet3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Unet 3D: Çiçek, Özgün, et al. \"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " # if Use_Data_augmentation:\n", - " # ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " # pdf.multi_cell(190, 5, txt = ref_4, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n", - "\n", - " print('------------------------------')\n", - " print('PDF report exported in '+model_path+'/'+model_name+'/')\n", - "\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'U-Net 3D'\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+qc_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n", - " pdf.ln(1)\n", - " if os.path.exists(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png'):\n", - " exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png').shape\n", - " pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png').shape\n", - " pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'IoU threshold optimisation', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.ln(1)\n", - " pdf.cell(120, 5, txt='Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh), align='L', ln=1)\n", - " pdf.ln(2)\n", - " exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png').shape\n", - " pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png', x=16, y=None, w = round(exp_size[1]/6), h = round(exp_size[0]/6))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. 
\"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Unet 3D: Çiçek, Özgün, et al. \"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/'+qc_model_name+'_QC_report.pdf')\n", - "\n", - " print('------------------------------')\n", - " print('QC PDF report exported in '+os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/')\n", - "\n", - "\n", - "# -------------- Other definitions -----------\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "prediction_prefix = 'Predicted_'\n", - "\n", - "\n", - "print('-------------------')\n", - "print('U-Net 3D and dependencies installed.')\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - " NORMAL = '\\033[0m' # white (normal)\n", - " \n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "if Notebook_version == list(Latest_notebook_version.columns):\n", - " print(\"This notebook is up-to-date.\")\n", - "\n", - "if not Notebook_version == list(Latest_notebook_version.columns):\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HLYcZR9gMv42" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AuESFimvMv43" - }, - "source": [ - "## **3.1. Choosing parameters**\n", - "\n", - "---\n", - "\n", - "### **Paths to training data and model**\n", - "\n", - "* **`training_source`** and **`training_target`** specify the paths to the training data. They can either be a single multipage TIFF file each or directories containing various multipage TIFF files in which case target and source files must be named identically within the respective directories. See Section 0 for a detailed description of the necessary directory structure.\n", - "\n", - "* **`model_name`** will be used when naming checkpoints. 
Adhere to a `lower_case_with_underscores` naming convention and beware of using the name of an existing model within the same folder, as it will be overwritten.\n", - "\n", - "* **`model_path`** specifies the directory where the model checkpoints and quality control logs will be saved.\n", - "\n", - "\n", - "**Note:** You can copy paths from the 'Files' tab by right-clicking any folder or file and selecting 'Copy path'. \n", - "\n", - "### **Training parameters**\n", - "\n", - "* **`number_of_epochs`** is the number of times the entire training data will be seen by the model. *Default: >100*\n", - "\n", - "* **`batch_size`** is the number of training patches of size `patch_size` that will be bundled together at each training step. *Default: 1*\n", - "\n", - "* **`patch_size`** specifies the size of the three-dimensional training patches in (x, y, z) that will be fed to the model. In order to avoid errors, preferably use a square aspect ratio or stick to the advanced parameters. *Default: <(512, 512, 16)*\n", - "\n", - "* **`validation_split_in_percent`** is the relative amount of training data that will be set aside for validation. *Default: 20* \n", - "\n", - "* **`downscaling_in_xy`** downscales the training images by the specified amount in x and y. This is useful to enforce isotropic pixel-size if the z resolution is lower than the xy resolution in the training volume or to capture a larger field-of-view while decreasing the memory requirements. *Default: 1*\n", - "\n", - "* **`image_pre_processing`** selects whether the training images are randomly cropped during training or resized to `patch_size`. Choose `randomly crop to patch_size` to shrink the field-of-view of the training images to the `patch_size`. *Default: resize to patch_size* \n", - "\n", - "* **`binary_target`** forces the target image to be binary. Choose this if your model is trained to perform binary segmentation tasks *Default: True* \n", - "\n", - "* **`loss_function`** defines the loss. Read more [here](https://keras.io/api/losses/). *Default: weighted_binary_crossentropy* \n", - "\n", - "* **`metrics`** defines the metric. Read more [here](https://keras.io/api/metrics/). *Default: dice* \n", - "\n", - "* **`optimizer`** defines the optimizer. Read more [here](https://keras.io/api/optimizers/). *Default: adam* \n", - "\n", - "* **`learning_rate`** defines the learning rate. Read more [here](https://keras.io/api/optimizers/). *Default: 0.001* \n", - "\n", - "**Note:** If a *ResourceExhaustedError* is raised in Section 4.1. during training, decrease `batch_size` and `patch_size`. Decrease `batch_size` first and if the error persists at `batch_size = 1`, reduce the `patch_size`. \n", - "\n", - "**Note:** The number of steps per epoch are calculated as `floor(augment_factor * (1 - validation_split) * num_of_slices / batch_size)` if `image_pre_processing` is `resize to patch_size` where `augment_factor` is three if `apply_data_augmentation` is `True` and one otherwise. The `num_of_slices` is the overall number of slices (z-depth) in the training set across all provided image volumes. If `image_pre_processing` is `randomly crop to patch_size`, the number of steps per epoch are calculated as `floor(augment_factor * volume / (crop_volume * batch_size))` where `volume` is the overall volume of the training data in pixels accounting for the validation split and `crop_volume` is defined as the volume in pixels based on the specified `patch_size`." 
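As a worked example of the random-crop formula above (hypothetical numbers, no augmentation so `augment_factor = 1`):

```
import math

# Hypothetical training stack: 160 slices of 768 x 1024 pixels, 20 % validation split,
# patch_size = (256, 256, 8), batch_size = 1, no downscaling, no augmentation
augment_factor = 1
num_of_slices = 160
slice_y, slice_x = 768, 1024
patch_size = (256, 256, 8)
batch_size = 1
downscale = 1
validation_split = 0.2

crop_volume = patch_size[0] * patch_size[1] * patch_size[2]          # 524288 voxels per patch
train_volume = slice_y * slice_x * (1 - validation_split) * num_of_slices
steps_per_epoch = math.floor(augment_factor * train_volume / (crop_volume * batch_size * downscale))
print(steps_per_epoch)   # 192
```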
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ewpNJ_I0Mv47", - "cellView": "form" - }, - "source": [ - "#@markdown ###Path to training data:\n", - "training_source = \"\" #@param {type:\"string\"}\n", - "training_target = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ---\n", - "\n", - "#@markdown ###Model name and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "full_model_path = os.path.join(model_path, model_name)\n", - "\n", - "#@markdown ---\n", - "\n", - "#@markdown ###Training parameters\n", - "number_of_epochs = 200#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Default advanced parameters\n", - "use_default_advanced_parameters = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown If not, please change:\n", - "\n", - "batch_size = 1#@param {type:\"number\"}\n", - "patch_size = (256,256,4) #@param {type:\"number\"} # in pixels\n", - "training_shape = patch_size + (1,)\n", - "image_pre_processing = 'randomly crop to patch_size' #@param [\"randomly crop to patch_size\", \"resize to patch_size\"]\n", - "\n", - "validation_split_in_percent = 20 #@param{type:\"number\"}\n", - "downscaling_in_xy = 1#@param {type:\"number\"} # in pixels\n", - "\n", - "binary_target = True #@param {type:\"boolean\"}\n", - "\n", - "loss_function = 'weighted_binary_crossentropy' #@param [\"weighted_binary_crossentropy\", \"binary_crossentropy\", \"categorical_crossentropy\", \"sparse_categorical_crossentropy\", \"mean_squared_error\", \"mean_absolute_error\"]\n", - "\n", - "metrics = 'dice' #@param [\"dice\", \"accuracy\"]\n", - "\n", - "optimizer = 'adam' #@param [\"adam\", \"sgd\", \"rmsprop\"]\n", - "\n", - "learning_rate = 0.0001 #@param{type:\"number\"}\n", - "\n", - "if image_pre_processing == \"randomly crop to patch_size\":\n", - " random_crop = True\n", - "else:\n", - " random_crop = False\n", - "\n", - "if use_default_advanced_parameters: \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 1\n", - " training_shape = (256,256,8,1)\n", - " validation_split_in_percent = 20\n", - " downscaling_in_xy = 1\n", - " random_crop = True\n", - " binary_target = True\n", - " loss_function = 'weighted_binary_crossentropy'\n", - " metrics = 'dice'\n", - " optimizer = 'adam'\n", - " learning_rate = 0.001 \n", - " \n", - "#@markdown ###Checkpointing parameters\n", - "checkpointing_period = 1 #@param {type:\"number\"}\n", - "\n", - "#@markdown If chosen only the best checkpoint is saved, otherwise a checkpoint is saved every checkpoint_period epochs:\n", - "save_best_only = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###Resume training\n", - "#@markdown Choose if training was interrupted:\n", - "resume_training = False #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###Transfer learning\n", - "#@markdown For transfer learning, do not select resume_training and specify a checkpoint_path below:\n", - "checkpoint_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if resume_training and checkpoint_path != \"\":\n", - " print('If resume_training is True while checkpoint_path is specified, resume_training will be set to False!')\n", - " resume_training = False\n", - " \n", - "\n", - "# Retrieve last checkpoint\n", - "if resume_training:\n", - " try:\n", - " ckpt_dir_list = glob(full_model_path + '/ckpt/*')\n", - " ckpt_dir_list.sort()\n", - " last_ckpt_path = ckpt_dir_list[-1]\n", - " print('Training will resume from checkpoint:', 
os.path.basename(last_ckpt_path))\n", - " except IndexError:\n", - " last_ckpt_path=None\n", - " print('CheckpointError: No previous checkpoints were found, training from scratch.')\n", - "elif not resume_training and checkpoint_path != \"\":\n", - " last_ckpt_path = checkpoint_path\n", - " assert os.path.isfile(last_ckpt_path), 'checkpoint_path does not exist!'\n", - "else:\n", - " last_ckpt_path=None\n", - "\n", - "# Instantiate Unet3D \n", - "model = Unet3D(shape=training_shape)\n", - "\n", - "#here we check that no model with the same name already exist\n", - "if not resume_training and os.path.exists(full_model_path): \n", - " print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n", - " # print('!! WARNING: Folder already exists and will be overwritten !!') \n", - " # shutil.rmtree(full_model_path)\n", - "\n", - "# if not os.path.exists(full_model_path):\n", - "# os.makedirs(full_model_path)\n", - "\n", - "# Show sample image\n", - "if os.path.isdir(training_source):\n", - " training_source_sample = sorted(glob(os.path.join(training_source, '*')))[0]\n", - " training_target_sample = sorted(glob(os.path.join(training_target, '*')))[0]\n", - "else:\n", - " training_source_sample = training_source\n", - " training_target_sample = training_target\n", - "\n", - "src_sample = tifffile.imread(training_source_sample)\n", - "src_sample = model._min_max_scaling(src_sample)\n", - "if binary_target:\n", - " tgt_sample = tifffile.imread(training_target_sample).astype(np.bool)\n", - "else:\n", - " tgt_sample = tifffile.imread(training_target_sample)\n", - "\n", - "src_down = transform.downscale_local_mean(src_sample[0], (downscaling_in_xy, downscaling_in_xy))\n", - "tgt_down = transform.downscale_local_mean(tgt_sample[0], (downscaling_in_xy, downscaling_in_xy)) \n", - "\n", - "if random_crop:\n", - " true_patch_size = None\n", - "\n", - " if src_down.shape[0] == training_shape[0]:\n", - " x_rand = 0\n", - " if src_down.shape[1] == training_shape[1]:\n", - " y_rand = 0\n", - " if src_down.shape[0] > training_shape[0]:\n", - " x_rand = np.random.randint(src_down.shape[0] - training_shape[0])\n", - " if src_down.shape[1] > training_shape[1]:\n", - " y_rand = np.random.randint(src_down.shape[1] - training_shape[1])\n", - " if src_down.shape[0] < training_shape[0] or src_down.shape[1] < training_shape[1]:\n", - " raise ValueError('Patch shape larger than (downscaled) source shape')\n", - "else:\n", - " true_patch_size = src_down.shape\n", - "\n", - "def scroll_in_z(z):\n", - " src_down = transform.downscale_local_mean(src_sample[z-1], (downscaling_in_xy,downscaling_in_xy))\n", - " tgt_down = transform.downscale_local_mean(tgt_sample[z-1], (downscaling_in_xy,downscaling_in_xy)) \n", - " if random_crop:\n", - " src_slice = src_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n", - " tgt_slice = tgt_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n", - " else:\n", - " \n", - " src_slice = transform.resize(src_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n", - " tgt_slice = transform.resize(tgt_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n", - "\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(src_slice, cmap='gray')\n", - " plt.title('Training source (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(tgt_slice, cmap='magma')\n", - " 
plt.title('Training target (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - " plt.savefig('/content/TrainingDataExample_Unet3D.png',bbox_inches='tight',pad_inches=0)\n", - " #plt.close()\n", - "\n", - "print('This is what the training images will look like with the chosen settings')\n", - "interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_sample.shape[0], step=1, value=0));\n", - "plt.show()\n", - "#Create a copy of an example slice and close the display.\n", - "scroll_in_z(z=int(src_sample.shape[0]/2))\n", - "# If you close the display, then the users can't interactively inspect the data\n", - "# plt.close()\n", - "\n", - "# Save model parameters\n", - "params = {'training_source': training_source,\n", - " 'training_target': training_target,\n", - " 'model_name': model_name,\n", - " 'model_path': model_path,\n", - " 'number_of_epochs': number_of_epochs,\n", - " 'batch_size': batch_size,\n", - " 'training_shape': training_shape,\n", - " 'downscaling': downscaling_in_xy,\n", - " 'true_patch_size': true_patch_size,\n", - " 'val_split': validation_split_in_percent/100,\n", - " 'random_crop': random_crop}\n", - "\n", - "params_df = pd.DataFrame.from_dict(params, orient='index')\n", - "\n", - "# apply_data_augmentation = False\n", - "# pdf_export(augmentation = apply_data_augmentation, pretrained_model = resume_training)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_jCy7xOx2g3" - }, - "source": [ - "## **3.2. Data augmentation**\n", - " \n", - "---\n", - " Augmenting the training data increases robustness of the model by simulating possible variations within the training data which avoids it from overfitting on small datasets. We therefore strongly recommended augmenting the data and making sure that the applied augmentations are reasonable.\n", - "\n", - "* **Gaussian blur** blurs images using Gaussian kernels with a sigma of `gaussian_sigma`. This augmentation step is applied with a probability of `gaussian_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/blur.html#gaussianblur).\n", - "\n", - "* **Linear contrast** modifies the contrast of images according to `127 + alpha *(pixel_value-127)`, where `pixel_value` and `alpha` are sampled uniformly from the interval `[contrast_min, contrast_max]`. This augmentation step is applied with a probability of `contrast_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/contrast.html#linearcontrast).\n", - "\n", - "* **Additive Gaussian noise** adds Gaussian noise sampled once per pixel from a normal distribution `N(0, s)`, where `s` is sampled from `[scale_min, scale_max]`. This augmentation step is applied with a probability of `noise_frequency`. 
Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/arithmetic.html#additivegaussiannoise).\n", - "\n", - "* **Add custom augmenters** allows you to create a custom augmentation pipeline using the [augmenters available in the imagug library](https://imgaug.readthedocs.io/en/latest/source/overview_of_augmenters.html).\n", - "In the example above, the augmentation pipeline is equivalent to: \n", - "```\n", - "seq = iaa.Sequential([\n", - " iaa.Sometimes(0.3, iaa.GammaContrast((0.5, 2.0)), \n", - " iaa.Sometimes(0.4, iaa.AverageBlur((0.5, 2.0)), \n", - " iaa.Sometimes(0.5, iaa.LinearContrast((0.4, 1.6)), \n", - "], random_order=True)\n", - "```\n", - " Note that there is no limit on the number of augmenters that can be chained together and that individual augmenter and parameter entries must be separated by `;`. Custom augmenters do not overwrite the preset augmentation steps (*Gaussian blur*, *Linear contrast* or *Additive Gaussian noise*). Also, the augmenters, augmenter parameters and augmenter frequencies must be entered such that each position within the string corresponds to the same augmentation step.\n", - "\n", - "* **`apply_data_augmentation`** ensures that data augmentation is randomly applied to the training data at each training step. This includes inverting the order of the slices within a training patch, as well as applying any augmenters that are added. *Default: True*\n", - "\n", - "* **`add_elastic_deform`** ensures that elastic grid-based deformations are applied as described in the original 3D U-Net paper. *Default: True*" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "DMqWq5-AxnFU", - "cellView": "form" - }, - "source": [ - "#@markdown ##**Augmentation options**\n", - "\n", - "#@markdown ###Data augmentation\n", - "\n", - "apply_data_augmentation = True #@param {type:\"boolean\"}\n", - "\n", - "# List of augmentations\n", - "augmentations = []\n", - "\n", - "#@markdown ###Gaussian blur\n", - "add_gaussian_blur = True #@param {type:\"boolean\"}\n", - "gaussian_sigma = 0.7#@param {type:\"number\"}\n", - "gaussian_frequency = 0.5 #@param {type:\"number\"}\n", - "\n", - "if add_gaussian_blur:\n", - " augmentations.append(iaa.Sometimes(gaussian_frequency, iaa.GaussianBlur(sigma=(0, gaussian_sigma))))\n", - "\n", - "#@markdown ###Linear contrast\n", - "add_linear_contrast = True #@param {type:\"boolean\"}\n", - "contrast_min = 0.4 #@param {type:\"number\"}\n", - "contrast_max = 1.6#@param {type:\"number\"}\n", - "contrast_frequency = 0.5 #@param {type:\"number\"}\n", - "\n", - "if add_linear_contrast:\n", - " augmentations.append(iaa.Sometimes(contrast_frequency, iaa.LinearContrast((contrast_min, contrast_max))))\n", - "\n", - "#@markdown ###Additive Gaussian noise\n", - "add_additive_gaussian_noise = False #@param {type:\"boolean\"}\n", - "scale_min = 0 #@param {type:\"number\"}\n", - "scale_max = 0.05 #@param {type:\"number\"}\n", - "noise_frequency = 0.5 #@param {type:\"number\"}\n", - "\n", - "if add_additive_gaussian_noise:\n", - " augmentations.append(iaa.Sometimes(noise_frequency, iaa.AdditiveGaussianNoise(scale=(scale_min, scale_max))))\n", - "\n", - "#@markdown ###Add custom augmenters\n", - "add_custom_augmenters = False #@param {type:\"boolean\"}\n", - "augmenters = \"\"\n", - "if add_custom_augmenters:\n", - " augmenters = \"\" #@param {type:\"string\"}\n", - "\n", - " augmenter_params = \"\" #@param {type:\"string\"}\n", - "\n", - " augmenter_frequency = \"\" #@param {type:\"string\"}\n", - "\n", - " aug_lst = 
augmenters.split(';')\n", - " aug_params_lst = augmenter_params.split(';')\n", - " aug_freq_lst = augmenter_frequency.split(';')\n", - "\n", - " assert len(aug_lst) == len(aug_params_lst) and len(aug_lst) == len(aug_freq_lst), 'The number of arguments in augmenters, augmenter_params and augmenter_frequency are not the same!'\n", - "\n", - " for __, (aug, param, freq) in enumerate(zip(aug_lst, aug_params_lst, aug_freq_lst)):\n", - " aug, param, freq = aug.strip(), param.strip(), freq.strip() \n", - " aug_func = iaa.Sometimes(eval(freq), getattr(iaa, aug)(eval(param)))\n", - " augmentations.append(aug_func)\n", - "\n", - "#@markdown ###Elastic deformations\n", - "add_elastic_deform = True #@param {type:\"boolean\"}\n", - "sigma = 2#@param {type:\"number\"}\n", - "points = 2#@param {type:\"number\"}\n", - "order = 2#@param {type:\"number\"}\n", - "\n", - "if add_elastic_deform:\n", - " deform_params = (sigma, points, order)\n", - "else:\n", - " deform_params = None\n", - "\n", - "train_generator = MultiPageTiffGenerator(training_source,\n", - " training_target,\n", - " batch_size=batch_size,\n", - " shape=training_shape,\n", - " augment=apply_data_augmentation,\n", - " augmentations=augmentations,\n", - " deform_augment=add_elastic_deform,\n", - " deform_augmentation_params=deform_params,\n", - " val_split=validation_split_in_percent/100,\n", - " random_crop=random_crop,\n", - " downscale=downscaling_in_xy,\n", - " binary_target=binary_target)\n", - "\n", - "val_generator = MultiPageTiffGenerator(training_source,\n", - " training_target,\n", - " batch_size=batch_size,\n", - " shape=training_shape,\n", - " val_split=validation_split_in_percent/100,\n", - " is_val=True,\n", - " random_crop=random_crop,\n", - " downscale=downscaling_in_xy,\n", - " binary_target=binary_target)\n", - "\n", - "\n", - "if apply_data_augmentation:\n", - " print('Data augmentation enabled.')\n", - " sample_src_aug, sample_tgt_aug = train_generator.sample_augmentation(random.randint(0, len(train_generator)))\n", - "\n", - " def scroll_in_z(z):\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(sample_src_aug[0,:,:,z-1,0], cmap='gray')\n", - " plt.title('Sample augmented source (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(sample_tgt_aug[0,:,:,z-1,0], cmap='magma')\n", - " plt.title('Sample training target (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " print('This is what the augmented training images will look like with the chosen settings')\n", - " interact(scroll_in_z, z=widgets.IntSlider(min=1, max=sample_src_aug.shape[3], step=1, value=0));\n", - "\n", - "else:\n", - " print('Data augmentation disabled.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MCGklf1vZf2M" - }, - "source": [ - "# **4. Train the network**\n", - "---\n", - "\n", - "**CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training times must be less than 12 hours! If training takes longer than 12 hours, please decrease `number_of_epochs`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1KYOuygETJkT" - }, - "source": [ - "## **4.1. 
Show model and start training**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lIUAOJ_LMv5E", - "cellView": "form" - }, - "source": [ - "#@markdown ## Show model summary\n", - "model.summary()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "CyQI4ssarUp4", - "cellView": "form" - }, - "source": [ - "#@markdown ##Start training\n", - "\n", - "#here we check that no model with the same name already exist, if so delete\n", - "if not resume_training and os.path.exists(full_model_path): \n", - " shutil.rmtree(full_model_path)\n", - " print(bcolors.WARNING+'!! WARNING: Folder already exists and has been overwritten !!'+bcolors.NORMAL) \n", - "\n", - "if not os.path.exists(full_model_path):\n", - " os.makedirs(full_model_path)\n", - "\n", - "pdf_export(augmentation = apply_data_augmentation, pretrained_model = resume_training)\n", - "\n", - "# Save file\n", - "params_df.to_csv(os.path.join(full_model_path, 'params.csv'))\n", - "\n", - "start = time.time()\n", - "# Start Training\n", - "model.train(epochs=number_of_epochs,\n", - " batch_size=batch_size,\n", - " train_generator=train_generator,\n", - " val_generator=val_generator,\n", - " model_path=model_path,\n", - " model_name=model_name,\n", - " loss=loss_function,\n", - " metrics=metrics,\n", - " optimizer=optimizer,\n", - " learning_rate=learning_rate,\n", - " ckpt_period=checkpointing_period,\n", - " save_best_ckpt_only=save_best_only,\n", - " ckpt_path=last_ckpt_path)\n", - "\n", - "print('Training successfully completed!')\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "#Create a pdf document with training summary\n", - "\n", - "pdf_export(trained = True, augmentation = apply_data_augmentation, pretrained_model = resume_training)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0Dfn8ZsEMv5d" - }, - "source": [ - "##**4.2. Download your model from Google Drive**\n", - "\n", - "---\n", - "Once training is complete, the trained model is automatically saved to your Google Drive, in the **`model_path`** folder that was specified in Section 3. Download the folder to avoid any unwanted surprises, since the data can be erased if you train another model using the same `model_path`." - ] - }, - { - "cell_type": "code", - "metadata": { - "scrolled": true, - "id": "iwNmp1PUzRDQ", - "cellView": "form" - }, - "source": [ - "#@markdown ##Download model directory\n", - "#@markdown 1. Specify the model_path in `model_path_download` otherwise the model sepcified in Section 3.1 will be downloaded\n", - "#@markdown 2. Run this cell to zip the model directory\n", - "#@markdown 3. 
Download the zipped file from the *Files* tab on the left\n", - "\n", - "from google.colab import files\n", - "\n", - "model_path_download = \"\" #@param {type:\"string\"}\n", - "\n", - "if len(model_path_download) == 0:\n", - " model_path_download = full_model_path\n", - "\n", - "model_name_download = os.path.basename(model_path_download)\n", - "\n", - "print('Zipping', model_name_download)\n", - "\n", - "zip_model_path = model_name_download + '.zip'\n", - "\n", - "!zip -r \"$zip_model_path\" \"$model_path_download\"\n", - "\n", - "print('Successfully saved zipped model directory as', zip_model_path)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_0Hynw3-xHp1" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "In this section the newly trained model can be assessed for performance. This involves inspecting the loss function in Section 5.1. and employing more advanced metrics in Section 5.2.\n", - "\n", - "**We highly recommend performing quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "eAJzMwPA6tlH", - "cellView": "form" - }, - "source": [ - "#@markdown ###Model to be evaluated:\n", - "#@markdown If left blank, the latest model defined in Section 3 will be evaluated:\n", - "\n", - "qc_model_name = \"\" #@param {type:\"string\"}\n", - "qc_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if len(qc_model_path) == 0 and len(qc_model_name) == 0:\n", - " qc_model_name = model_name\n", - " qc_model_path = model_path\n", - "\n", - "full_qc_model_path = os.path.join(qc_model_path, qc_model_name)\n", - "\n", - "if os.path.exists(full_qc_model_path):\n", - " print(qc_model_name + ' will be evaluated')\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dhJROwlAMv5o" - }, - "source": [ - "## **5.1. Inspecting loss function**\n", - "---\n", - "\n", - "**The training loss** is the error between prediction and target after each epoch calculated across the training data while the **validation loss** calculates the error on the (unseen) validation data. During training these values should decrease until converging at which point the model has been sufficiently trained. If the validation loss starts increasing while the training loss has plateaued, the model has overfit on the training data which reduces its ability to generalise. Aim to halt training before this point.\n", - "\n", - "**Note:** For a more in-depth explanation please refer to [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols et al.\n", - "\n", - "\n", - "The accuracy is another performance metric that is calculated after each epoch. We use the [Sørensen–Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) to score the prediction accuracy. 
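As a reminder of what the accuracy (Dice) curves plotted below measure, the Sørensen–Dice coefficient of two binary masks A and B is 2|A∩B| / (|A| + |B|). A minimal NumPy sketch on two small hypothetical masks:

```python
import numpy as np

# Hypothetical binary masks (prediction vs. target), for illustration only.
pred   = np.array([[1, 1, 0],
                   [0, 1, 0]], dtype=bool)
target = np.array([[1, 0, 0],
                   [0, 1, 1]], dtype=bool)

intersection = np.logical_and(pred, target).sum()
dice = 2 * intersection / (pred.sum() + target.sum())
print(dice)  # 2 * 2 / (3 + 3) ≈ 0.667
```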
\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vMzSP50kMv5p", - "cellView": "form" - }, - "source": [ - "#@markdown ##Visualise loss and accuracy\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "accuracyDataFromCSV = []\n", - "valaccuracyDataFromCSV = []\n", - "\n", - "with open(full_qc_model_path + '/Quality Control/training_evaluation.csv', 'r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[2]))\n", - " vallossDataFromCSV.append(float(row[4]))\n", - " accuracyDataFromCSV.append(float(row[1]))\n", - " valaccuracyDataFromCSV.append(float(row[3]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training and validation loss', fontsize=14)\n", - "plt.ylabel('Loss', fontsize=12)\n", - "plt.xlabel('Epochs', fontsize=12)\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.plot(epochNumber,accuracyDataFromCSV, label='Training accuracy')\n", - "plt.plot(epochNumber,valaccuracyDataFromCSV, label='Validation accuracy')\n", - "plt.title('Training and validation accuracy', fontsize=14)\n", - "plt.ylabel('Dice', fontsize=12)\n", - "plt.xlabel('Epochs', fontsize=12)\n", - "plt.legend()\n", - "plt.savefig(full_qc_model_path + '/Quality Control/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n", - "plt.show()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X5_92nL2xdP6" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "This section will provide both a visual indication of the model performance by comparing the overlay of the predicted and source volume." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "w90MdriMxhjD", - "cellView": "form" - }, - "source": [ - "#@markdown ##Compare prediction and ground-truth on testing data\n", - "\n", - "#@markdown Provide an unseen annotated dataset to determine the performance of the model:\n", - "\n", - "testing_source = \"\" #@param{type:\"string\"}\n", - "testing_target = \"\" #@param{type:\"string\"}\n", - "\n", - "qc_dir = full_qc_model_path + '/Quality Control'\n", - "predict_dir = qc_dir + '/Prediction'\n", - "if os.path.exists(predict_dir):\n", - " shutil.rmtree(predict_dir)\n", - "\n", - "os.makedirs(predict_dir)\n", - "\n", - "# predict_dir + '/' + \n", - "predict_path = os.path.splitext(os.path.basename(testing_source))[0] + '_prediction.tif'\n", - "\n", - "def last_chars(x):\n", - " return(x[-11:])\n", - "\n", - "try:\n", - " ckpt_dir_list = glob(full_qc_model_path + '/ckpt/*')\n", - " ckpt_dir_list.sort(key=last_chars)\n", - " last_ckpt_path = ckpt_dir_list[0]\n", - " print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n", - "except IndexError:\n", - " raise CheckpointError('No previous checkpoints were found, please retrain model.')\n", - "\n", - "# Load parameters\n", - "params = pd.read_csv(os.path.join(full_qc_model_path, 'params.csv'), names=['val'], header=0, index_col=0) \n", - "\n", - "model = Unet3D(shape=params.loc['training_shape', 'val'])\n", - "\n", - "prediction = model.predict(testing_source, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - "\n", - "tifffile.imwrite(predict_path, prediction.astype('float32'), imagej=True)\n", - "\n", - "print('Predicted images!')\n", - "\n", - "qc_metrics_path = full_qc_model_path + '/Quality Control/QC_metrics_' + qc_model_name + '.csv'\n", - "\n", - "test_target = tifffile.imread(testing_target)\n", - "test_source = tifffile.imread(testing_source)\n", - "test_prediction = tifffile.imread(predict_path)\n", - "\n", - "def scroll_in_z(z):\n", - "\n", - " plt.figure(figsize=(25,5))\n", - " # Source\n", - " plt.subplot(1,4,1)\n", - " plt.axis('off')\n", - " plt.imshow(test_source[z-1], cmap='gray')\n", - " plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n", - "\n", - " # Target (Ground-truth)\n", - " plt.subplot(1,4,2)\n", - " plt.axis('off')\n", - " plt.imshow(test_target[z-1], cmap='magma')\n", - " plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n", - "\n", - " # Prediction\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " plt.imshow(test_prediction[z-1], cmap='magma')\n", - " plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n", - " \n", - " # Overlay\n", - " plt.subplot(1,4,4)\n", - " plt.axis('off')\n", - " plt.imshow(test_target[z-1], cmap='Greens')\n", - " plt.imshow(test_prediction[z-1], alpha=0.5, cmap='Purples')\n", - " plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n", - " plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n", - "interact(scroll_in_z, z=widgets.IntSlider(min=1, max=test_source.shape[0], step=1, value=0));" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aIvRxpZlsFeZ" - }, - "source": [ - "## **5.3. Determine best Intersection over Union and threshold**\n", - "---\n", - "\n", - "**Note:** This section is only relevant if the target image is a binary mask and `binary_target` is selected in Section 3! 
\n", - "\n", - "This section will provide both a visual and a quantitative indication of the model performance by comparing the overlay of the predicted and source volume, as well as computing the highest [**Intersection over Union**](https://en.wikipedia.org/wiki/Jaccard_index) (IoU) score. The IoU is also known as the Jaccard Index. \n", - "\n", - "The best threshold is calculated using the IoU. Each threshold value from 0 to 255 is tested and the threshold with the highest score is deemed the best. The IoU is calculated for the entire volume in 3D." - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "XhkeZTFusHA8" - }, - "source": [ - "\n", - "#@markdown ##Calculate Intersection over Union and best threshold \n", - "prediction = tifffile.imread(predict_path)\n", - "prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - "\n", - "target = tifffile.imread(testing_target).astype(np.bool)\n", - "\n", - "def iou_vs_threshold(prediction, target):\n", - " threshold_list = []\n", - " IoU_scores_list = []\n", - "\n", - " for threshold in range(0,256): \n", - " mask = prediction > threshold\n", - "\n", - " intersection = np.logical_and(target, mask)\n", - " union = np.logical_or(target, mask)\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - "\n", - " threshold_list.append(threshold)\n", - " IoU_scores_list.append(iou_score)\n", - "\n", - " return threshold_list, IoU_scores_list\n", - "\n", - "threshold_list, IoU_scores_list = iou_vs_threshold(prediction, target)\n", - "thresh_arr = np.array(list(zip(threshold_list, IoU_scores_list)))\n", - "best_thresh = int(np.where(thresh_arr == np.max(thresh_arr[:,1]))[0])\n", - "best_iou = IoU_scores_list[best_thresh]\n", - "\n", - "print('Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh))\n", - "\n", - "def adjust_threshold(threshold, z):\n", - "\n", - " f=plt.figure(figsize=(25,5))\n", - " plt.subplot(1,4,1)\n", - " plt.imshow((prediction[z-1] > threshold).astype('uint8'), cmap='magma')\n", - " plt.title('Prediction (Threshold = ' + str(threshold) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,4,2)\n", - " plt.imshow(target[z-1], cmap='magma')\n", - " plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " plt.imshow(test_source[z-1], cmap='gray')\n", - " plt.imshow((prediction[z-1] > threshold).astype('uint8'), alpha=0.4, cmap='Reds')\n", - " plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n", - "\n", - " plt.subplot(1,4,4)\n", - " plt.title('Threshold vs. IoU', fontsize=15)\n", - " plt.plot(threshold_list, IoU_scores_list)\n", - " plt.plot(threshold, IoU_scores_list[threshold], 'ro') \n", - " plt.ylabel('IoU score')\n", - " plt.xlabel('Threshold')\n", - " plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png',bbox_inches=matplotlib.transforms.Bbox([[17.5,0],[23,5]]),pad_inches=0)\n", - " plt.show()\n", - "\n", - "interact(adjust_threshold, \n", - " threshold=widgets.IntSlider(min=0, max=255, step=1, value=best_thresh),\n", - " z=widgets.IntSlider(min=1, max=prediction.shape[0], step=1, value=0));\n", - "\n", - "#Make a pdf summary of the QC results\n", - "\n", - "qc_pdf_export()\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WoBhNSW_dupQ" - }, - "source": [ - "## **5.4. 
Export your model into the BioImage Model Zoo format**\n", - "---\n", - "This section exports the model into the BioImage Model Zoo format so it can be used directly with DeepImageJ. The new files will be stored in the model folder specified at the beginning of Section 5. \n", - "\n", - "Once the cell is executed, you will find a new zip file with the name specified in `Trained_model_name.bioimage.io.model`.\n", - "\n", - "To use it with deepImageJ, download it and unzip it in the ImageJ/models/ or Fiji/models/ folder of your local machine. \n", - "\n", - "In ImageJ, open the example image given within the downloaded zip file. Go to Plugins > DeepImageJ > DeepImageJ Run. Choose this model from the list and click OK.\n", - "\n", - " More information at https://deepimagej.github.io/deepimagej/" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "n_uVao0edw2q", - "cellView": "form" - }, - "source": [ - "# ####\n", - "from pydeepimagej.yaml import BioImageModelZooConfig\n", - "import urllib\n", - "import warnings\n", - "warnings. filterwarnings(\"ignore\") \n", - "\n", - "# ------------- User input ------------\n", - "# information about the model\n", - "#@markdown ##Introduce the metadata of the model architecture:\n", - "Trained_model_name = \"\" #@param {type:\"string\"}\n", - "Trained_model_authors = \"[Author 1, Author 2, Author 3]\" #@param {type:\"string\"}\n", - "\n", - "Trained_model_description = \"\"#@param {type:\"string\"}\n", - "Trained_model_license = 'MIT'#@param {type:\"string\"}\n", - "Trained_model_references = [\"Çiçek, Özgün, et al. MICCAI 2016\", \"Lucas von Chamier et al. biorXiv 2020\"]\n", - "Trained_model_DOI = [\"https://doi.org/10.1007/978-3-319-46723-8_49\", \"https://doi.org/10.1101/2020.03.20.000133\"]\n", - "\n", - "# Add example image information\n", - "# ---------------------------------------\n", - "#@markdown ##Choose a threshold for DeepImageJ's postprocessing macro:\n", - "Use_The_Best_Average_Threshold = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "threshold = 155 #@param {type:\"number\"}\n", - "if Use_The_Best_Average_Threshold:\n", - " threshold = best_thresh\n", - "\n", - "#@markdown ##Introduce the voxel size (pixel size for each Z-slice and the distance between Z-salices) (in microns) of the image provided as an example of the model processing:\n", - "# information about the example image\n", - "PixelSize = 1 #@param {type:\"number\"}\n", - "Zdistance = 1 #@param {type:\"number\"}\n", - "#@markdown ##Do you want to choose the exampleimage?\n", - "default_example_image = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "fileID = \"\" #@param {type:\"string\"}\n", - "if default_example_image:\n", - " fileID = testing_source\n", - " \n", - "example_image = tifffile.imread(fileID) \n", - "# Z-dim first\n", - "z_size = example_image.shape[0]\n", - "z_size = np.int(z_size/2)\n", - "\n", - "example_image = example_image[z_size-10:z_size+10]\n", - "path_example_im = \"/content/example_image_biomodelzoo.tif\"\n", - "tifffile.imsave(path_example_im, example_image)\n", - "\n", - "\n", - "\n", - "# Load model parameters\n", - "# ---------------------------------------\n", - "def last_chars(x):\n", - " return(x[-11:])\n", - "try:\n", - " ckpt_dir_list = glob(full_qc_model_path + '/ckpt/*')\n", - " ckpt_dir_list.sort(key=last_chars)\n", - " last_ckpt_path = ckpt_dir_list[0]\n", - " print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n", - "except IndexError:\n", - " 
raise CheckpointError('No previous checkpoints were found, please retrain model.')\n", - "params = pd.read_csv(os.path.join(full_qc_model_path, 'params.csv'), names=['val'], header=0, index_col=0) \n", - "\n", - "\n", - "\n", - "# Load the model and process the example image\n", - "# ---------------------------------------\n", - "model = Unet3D(shape=params.loc['training_shape', 'val'])\n", - "# prediction = model.predict(fileID, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - "prediction = model.predict(path_example_im, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - "\n", - "prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - "mask = prediction > threshold\n", - "# Z-dim first\n", - "# mask = mask[:20]\n", - "\n", - "\n", - "\n", - "\n", - "# ------------- Execute bioimage model zoo configuration ------------\n", - "# Check minimum size: it is [8,8] for the 2D XY plane\n", - "keras_model = model.model\n", - "# pooling_steps = 0\n", - "# for keras_layer in keras_model.layers:\n", - "# if keras_layer.name.startswith('max') or \"pool\" in keras_layer.name:\n", - "# pooling_steps += 1\n", - "# MinimumSize = [2**(pooling_steps), 2**(pooling_steps)]\n", - "MinimumSize = keras_model.input_shape[1:-1]\n", - "dij_config = BioImageModelZooConfig(keras_model, MinimumSize)\n", - "\n", - "# Model developer details\n", - "dij_config.Authors = Trained_model_authors[1:-1].split(',')\n", - "dij_config.Description = Trained_model_description\n", - "dij_config.Name = Trained_model_name\n", - "dij_config.References = Trained_model_references\n", - "dij_config.DOI = Trained_model_DOI\n", - "dij_config.License = Trained_model_license\n", - "\n", - "# Additional information about the model\n", - "dij_config.GitHub = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic'\n", - "dij_config.Date = datetime.now()\n", - "dij_config.Documentation = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki'\n", - "dij_config.Tags = ['ZeroCostDL4Mic', 'deepimagej', 'segmentation', '3DUNet']\n", - "dij_config.Framework = 'tensorflow'\n", - "\n", - "# Add the information about the test image. 
Note here PixelSize should be given in microns\n", - "dij_config.add_test_info(example_image, mask, [PixelSize, PixelSize, Zdistance])\n", - "dij_config.create_covers([example_image, mask])\n", - "dij_config.Covers = ['./input.png', './output.png']\n", - "\n", - "# Store the model weights\n", - "# ---------------------------------------\n", - "# used_bioimage_model_for_training_URL = \"/Some/URI/\"\n", - "# dij_config.Parent = used_bioimage_model_for_training_URL\n", - "\n", - "# Add weights information\n", - "format_authors = [\"pydeepimagej\"]\n", - "dij_config.add_weights_formats(keras_model, 'TensorFlow', \n", - " parent=\"keras_hdf5\",\n", - " authors=[a for a in format_authors])\n", - "dij_config.add_weights_formats(keras_model, 'KerasHDF5', \n", - " authors=[a for a in format_authors])\n", - "\n", - "## Preprocessing and postprocessing\n", - "# -------------------------------------------------\n", - "## Prepare preprocessing file_ min_max_scaling\n", - "path_preprocessing = \"per_sample_scale_range.ijm\"\n", - "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/bioimage.io/per_sample_scale_range.ijm\", path_preprocessing )\n", - "\n", - "# Modify the threshold in the macro to the chosen threshold\n", - "ijmacro = open(path_preprocessing,\"r\") \n", - "list_of_lines = ijmacro. readlines()\n", - "# Line 21 is the one corresponding to the optimal threshold\n", - "list_of_lines[24] = \"min_percentile = 0;\\n\"\n", - "list_of_lines[25] = \"max_percentile = 100;\\n\"\n", - "ijmacro.close()\n", - "ijmacro = open(path_preprocessing,\"w\") \n", - "ijmacro. writelines(list_of_lines)\n", - "ijmacro. close()\n", - "\n", - "## Prepare postprocessing file\n", - "path_postprocessing = \"binarize.ijm\"\n", - "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/bioimage.io/binarize.ijm\", path_postprocessing )\n", - "\n", - "# Modify the threshold in the macro to the chosen threshold\n", - "ijmacro = open(path_postprocessing,\"r\") \n", - "list_of_lines = ijmacro. readlines()\n", - "# Line 21 is the one corresponding to the optimal threshold\n", - "list_of_lines[11] = \"optimalThreshold = {};\\n\".format(threshold/255) # The output in DeepImageJ will not be converted to the range [0,255], so the threshold is adjusted.\n", - "ijmacro.close()\n", - "ijmacro = open(path_postprocessing,\"w\") \n", - "ijmacro. writelines(list_of_lines)\n", - "ijmacro. close()\n", - "\n", - "\n", - "\n", - "# Include the info about the macros \n", - "dij_config.Preprocessing = [path_preprocessing]\n", - "dij_config.Preprocessing_files = [path_preprocessing]\n", - "dij_config.Postprocessing = [path_postprocessing]\n", - "dij_config.Postprocessing_files = [path_postprocessing]\n", - "dij_config.add_bioimageio_spec('pre-processing', 'percentile', mode='per_sample', axes='xyzc', min_percentile=0, max_percentile=100)\n", - "dij_config.add_bioimageio_spec('post-processing', 'binarize', threshold=threshold)\n", - "\n", - "## EXPORT THE MODEL TO AN EXISTING PATH OR CREATE IT\n", - "deepimagej_model_path = os.path.join(full_qc_model_path,qc_model_name+'.bioimage.io.model')\n", - "# if not os.path.exists(deepimagej_model_path):\n", - "# os.mkdir(deepimagej_model_path)\n", - "dij_config.export_model(deepimagej_model_path)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-tJeeJjLnRkP" - }, - "source": [ - "# **6. 
Using the trained model**\n", - "\n", - "---\n", - "\n", - "Once sufficient performance of the trained model has been established using Section 5, the network can be used to segment unseen volumetric data." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d8wuQGjoq6eN" - }, - "source": [ - "## **6.1. Generate predictions from unseen dataset**\n", - "---\n", - "\n", - "The most recently trained model can now be used to predict segmentation masks on unseen images. If you want to use an older model, leave `model_path` blank. Predicted output images are saved in `output_path` as Image-J compatible TIFF files.\n", - "\n", - "## **Prediction parameters**\n", - "\n", - "* **`source_path`** specifies the location of the source \n", - "image volume.\n", - "\n", - "* **`output_directory`** specified the directory where the output predictions are stored.\n", - "\n", - "* **`binary_target`** should be chosen if the network is trained to predict binary segmentation masks.\n", - "\n", - "* **`threshold`** can be calculated in Section 5 and is used to generate binary masks from the predictions.\n", - "\n", - "* **`big_tiff`** should be chosen if the expected prediction exceeds 4GB. The predictions will be saved using the BigTIFF format. Beware that this might substantially reduce the prediction speed. *Default: False* \n", - "\n", - "* **`prediction_depth`** is only relevant if the prediction is saved as a BigTIFF. The prediction will not be performed in one go to not deplete the memory resources. Instead, the prediction is iteratively performed on a subset of the entire volume with shape `(source.shape[0], source.shape[1], prediction_depth)`. *Default: 32*\n", - "\n", - "* **`model_path`** specifies the path to a model other than the most recently trained." 
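To gauge how `prediction_depth` trades memory use against the number of prediction passes, a small back-of-the-envelope sketch (the volume dimensions below are hypothetical):

```python
import math

# Hypothetical unseen volume: 1000 z-slices of 2048 x 2048 pixels.
z_slices, height, width = 1000, 2048, 2048
prediction_depth = 32

passes = math.ceil(z_slices / prediction_depth)
bytes_per_pass = prediction_depth * height * width * 4     # float32 prediction chunk
print(passes, round(bytes_per_pass / 1e9, 2))              # 32 passes of ~0.54 GB each
```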
- ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "DEmhPh5fsWX2" - }, - "source": [ - "#@markdown ## Download example volume\n", - "\n", - "#@markdown This can take up to an hour\n", - "\n", - "import requests \n", - "import os\n", - "from tqdm.notebook import tqdm \n", - "\n", - "\n", - "def download_from_url(url, save_as):\n", - " file_url = url\n", - " r = requests.get(file_url, stream=True) \n", - " \n", - " with open(save_as, 'wb') as file: \n", - " for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=3275073, ncols=1000):\n", - " if block:\n", - " file.write(block) \n", - "\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/volumedata.tif', 'example_dataset/volumedata.tif')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "y2TD5p7MZrEb", - "cellView": "form" - }, - "source": [ - "#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then run the cell to predict outputs from your unseen images.\n", - "\n", - "source_path = \"\" #@param {type:\"string\"}\n", - "output_directory = \"\" #@param {type:\"string\"}\n", - "\n", - "if not os.path.exists(output_directory):\n", - " os.makedirs(output_directory)\n", - "\n", - "output_path = os.path.join(output_directory, os.path.splitext(os.path.basename(source_path))[0] + '_predicted.tif')\n", - "#@markdown ###Prediction parameters:\n", - "\n", - "binary_target = True #@param {type:\"boolean\"}\n", - "\n", - "save_probability_map = False #@param {type:\"boolean\"}\n", - "\n", - "#@markdown Determine best threshold in Section 5.2.\n", - "\n", - "use_calculated_threshold = True #@param {type:\"boolean\"}\n", - "threshold = 200#@param {type:\"number\"}\n", - "\n", - "# Tifffile library issues means that images cannot be appended to \n", - "#@markdown Choose if prediction file exceeds 4GB or if input file is very large (above 2GB). Image volume saved as BigTIFF.\n", - "big_tiff = False #@param {type:\"boolean\"}\n", - "\n", - "#@markdown Reduce `prediction_depth` if runtime runs out of memory during prediction. 
Only relevant if prediction saved as BigTIFF\n", - "\n", - "prediction_depth = 32#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Model to be evaluated\n", - "#@markdown If left blank, the latest model defined in Section 5 will be evaluated\n", - "\n", - "full_model_path_ = \"\" #@param {type:\"string\"}\n", - "\n", - "if len(full_model_path_) == 0:\n", - " full_model_path_ = os.path.join(qc_model_path, qc_model_name) \n", - "\n", - "\n", - "\n", - "# Load parameters\n", - "params = pd.read_csv(os.path.join(full_model_path_, 'params.csv'), names=['val'], header=0, index_col=0) \n", - "model = Unet3D(shape=params.loc['training_shape', 'val'])\n", - "\n", - "if use_calculated_threshold:\n", - " threshold = best_thresh\n", - "\n", - "def last_chars(x):\n", - " return(x[-11:])\n", - "\n", - "try:\n", - " ckpt_dir_list = glob(full_model_path_ + '/ckpt/*')\n", - " ckpt_dir_list.sort(key=last_chars)\n", - " last_ckpt_path = ckpt_dir_list[0]\n", - " print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n", - "except IndexError:\n", - " raise CheckpointError('No previous checkpoints were found, please retrain model.')\n", - "\n", - "src = tifffile.imread(source_path)\n", - "\n", - "if src.nbytes >= 4e9:\n", - " big_tiff = True\n", - " print('The source file exceeds 4GB in memory, prediction will be saved as BigTIFF!')\n", - "\n", - "if binary_target:\n", - " if not big_tiff:\n", - " prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " prediction = (prediction > threshold).astype('float32')\n", - "\n", - " tifffile.imwrite(output_path, prediction, imagej=True)\n", - "\n", - " else:\n", - " with tifffile.TiffWriter(output_path, bigtiff=True) as tif:\n", - " for i in tqdm(range(0, src.shape[0], prediction_depth)):\n", - " prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " prediction = (prediction > threshold).astype('float32')\n", - " \n", - " for j in range(prediction.shape[0]):\n", - " tif.save(prediction[j])\n", - "\n", - "if not binary_target or save_probability_map:\n", - " if not binary_target:\n", - " prob_map_path = output_path\n", - " else:\n", - " prob_map_path = os.path.splitext(output_path)[0] + '_prob_map.tif'\n", - " \n", - " if not big_tiff:\n", - " prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " tifffile.imwrite(prob_map_path, prediction.astype('float32'), imagej=True)\n", - "\n", - " else:\n", - " with tifffile.TiffWriter(prob_map_path, bigtiff=True) as tif:\n", - " for i in tqdm(range(0, src.shape[0], prediction_depth)):\n", - " prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " \n", - " for j in range(prediction.shape[0]):\n", - " tif.save(prediction[j])\n", - "\n", - "print('Predictions saved as', output_path)\n", - "\n", - 
"src_volume = tifffile.imread(source_path)\n", - "pred_volume = tifffile.imread(output_path)\n", - "\n", - "def scroll_in_z(z):\n", - " \n", - " f=plt.figure(figsize=(25,5))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(src_volume[z-1], cmap='gray')\n", - " plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(pred_volume[z-1], cmap='magma')\n", - " plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - "interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_volume.shape[0], step=1, value=0));\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.2. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UvSlTaH14s3t" - }, - "source": [ - "\n", - "#**Thank you for using 3D U-Net!**" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_3D_ZeroCostDL4Mic_BioImageModelZoo_export.ipynb","provenance":[{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **U-Net (3D)**\n"," ---\n","\n"," The 3D U-Net was first introduced by [Çiçek et al](https://arxiv.org/abs/1606.06650) for learning dense volumetric segmentations from sparsely annotated ground-truth data building upon the original U-Net architecture by [Ronneberger et al](https://arxiv.org/abs/1505.04597). \n","\n","**This particular implementation allows supervised learning between any two types of 3D image data. 
If you are interested in image segmentation of 2D datasets, you should use the 2D U-Net notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project ([ZeroCostDL4Mic](/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki)) jointly developed by the [Jacquemet](https://cellmig.org/) and [Henriques](https://henriqueslab.github.io/) laboratories and created by Daniel Krentzel.\n","\n","This notebook is largely based on the following paper: \n","\n","[**3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation**](https://arxiv.org/pdf/1606.06650.pdf) by Özgün Çiçek *et al.* published on arXiv in 2016\n","\n","The following three Python libraries play an important role in the notebook: \n","\n","1. [**Elasticdeform**](/~https://github.com/gvtulder/elasticdeform)\n"," by Gijs van Tulder was used to augment the 3D training data using elastic grid-based deformations as described in the original 3D U-Net paper. \n","\n","2. [**Tifffile**](/~https://github.com/cgohlke/tifffile) by Christoph Gohlke is a great library for reading and writing TIFF files. \n","\n","3. [**Imgaug**](/~https://github.com/aleju/imgaug) by Alexander Jung *et al.* is an amazing library for image augmentation in machine learning - it is the most complete and extensive image augmentation package I have found to date. \n","\n","The [example dataset](https://www.epfl.ch/labs/cvlab/data/data-em/) represents a 5x5x5µm section taken from the CA1 hippocampus region of the brain with annotated mitochondria and was acquired by Graham Knott and Marco Cantoni at EPFL.\n","\n","\n","**Please also cite the original paper and relevant Python libraries when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use ZeroCostDL4Mic notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cells: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new one by clicking `+ Text`.\n","\n","**Code cells** contain code which can be modified by selecting the cell. To execute the cell, move your cursor to the `[]`-symbol on the left side of the cell (a play button should appear). Click it to execute the cell. Once the cell is fully executed, the animation stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","Three tabs are located on the upper left side of the notebook:\n","\n","1. *Table of contents* contains the structure of the notebook. Click the headers to move quickly between sections.\n","\n","2. *Code snippets* provides a wide array of example code specific to Google Colab. You can ignore this when using this notebook.\n","\n","3. *Files* displays the current working directory. We will mount your Google Drive in Section 1.2. 
so that you can access your files and save them permanently.\n","\n","**Important:** All uploaded files are purged once the runtime ends.\n","\n","**Note:** The directory *sample data* in *Files* contains default files. Do not upload anything there!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive by clicking *File* -> *Save a copy in Drive*.\n","\n","To **edit a cell**, double click on the text. This will either display the source code (in code cells) or the [markdown](https://colab.research.google.com/notebooks/markdown_guide.ipynb#scrollTo=70pYkR9LiOV0) (in text cells).\n","You can use `#` in code cells to comment out parts of the code. This allows you to keep the original piece of code while not executing it."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n","\n","As the network operates in three dimensions, careful consideration should be given to correctly pre-processing the data. Ensure that the structure of interest does not substantially change between slices - image volumes with isotropic pixel sizes are ideal for this architecture.\n","\n","Each image volume must be provided as an **8-bit** or **binary multipage TIFF file** to maintain the correct ordering of individual image slices. If more than one image volume has been annotated, source and target files must be named identically and placed in separate directories. In case only one image volume has been annotated, source and target files do not have to be placed in separate directories and can be named differently, as long as their paths are explicitly provided in Section 3. \n","\n","**Prepare two datasets** (*training* and *testing*) for quality control purposes. Make sure that the *testing* dataset does not overlap with the *training* dataset and is ideally sourced from a different acquisition and sample to ensure robustness of the trained model. \n","\n","\n","---\n","\n","\n","### **Directory structure**\n","\n","Make sure to adhere to one of the following directory structures. If only one annotated training volume exists, choose the first structure. 
In case more than one training volume is available, choose the second structure.\n","\n","**Structure 1:** Only one training volume\n","```\n","path/to/directory/with/one/training/volume\n","│--training_source.tif\n","│--training_target.tif\n","| \n","│--testing_source.tif\n","|--testing_target.tif \n","|\n","|--data_to_predict_on.tif\n","|--prediction_results.tif\n","\n","```\n","**Structure 2:** Various training volumes\n","```\n","path/to/directory/with/various/training/volumes\n","│--testing_source.tif\n","|--testing_target.tif \n","|\n","└───training\n","| └───source\n","| | |--training_volume_one.tif\n","| | |--training_volume_two.tif\n","| | |--...\n","| | |--training_volume_n.tif\n","| |\n","| └───target\n","| |--training_volume_one.tif\n","| |--training_volume_two.tif\n","| |--...\n","| |--training_volume_n.tif\n","|\n","|--data_to_predict_on.tif\n","|--prediction_results.tif\n","```\n","**Note:** Naming directories is completely up to you, as long as the paths are correctly specified throughout the notebook.\n","\n","\n","---\n","\n","\n","### **Important note**\n","\n","* If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do so), you will need to run **Sections 1 - 4**, then use **Section 5** to assess the quality of your model and **Section 6** to run predictions using the model that you trained.\n","\n","* If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 5** to assess the quality of your model.\n","\n","* If you only wish to **Run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","id":"M-GZMaL7pd8a"},"source":["#@markdown ##**Download example dataset**\n","\n","#@markdown This usually takes a few minutes. 
The images are saved in *example_dataset*.\n","\n","import requests \n","import os\n","from tqdm.notebook import tqdm \n","\n","def make_directory(dir):\n"," if not os.path.exists(dir):\n"," os.makedirs(dir)\n","\n","def download_from_url(url, save_as):\n"," file_url = url\n"," r = requests.get(file_url, stream=True) \n"," \n"," with open(save_as, 'wb') as file: \n"," for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=126875, ncols=1000):\n"," if block:\n"," file.write(block) \n","\n","\n","make_directory('example_dataset')\n","\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training.tif', 'example_dataset/training.tif')\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training_groundtruth.tif', 'example_dataset/training_groundtruth.tif')\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing.tif', 'example_dataset/testing.tif')\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing_groundtruth.tif', 'example_dataset/testing_groundtruth.tif')\n","\n","print('Example dataset successfully downloaded!')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"zxELU7CIp4oF"},"source":["#@markdown ##Unzip pre-trained model directory\n","\n","#@markdown 1. Upload a zipped model directory using the *Files* tab\n","#@markdown 2. Run this cell to unzip your model file\n","#@markdown 3. The model directory will appear in the *Files* tab \n","\n","from google.colab import files\n","\n","zipped_model_file = \"\" #@param {type:\"string\"}\n","\n","!unzip \"$zipped_model_file\""],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. Install 3D U-Net dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"PVbFEzo1DgUB"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"cellView":"form","id":"q4wM9Sr0Dbbf"},"source":["#@markdown ##Play to install 3D U-Net dependencies\n","\n","!pip install pydeepimagej==2.1.2\n","# !pip uninstall -y keras-nightly\n","!pip install data\n","!pip install fpdf\n","!pip install h5py==2.10\n","\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3pcnkKZaDjaF"},"source":["\n","## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","**Ignore the following error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
"]},{"cell_type":"markdown","metadata":{"id":"9rNP7LpoDoGk"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["#@markdown ##Load key 3D U-Net dependencies and instantiate network\n","Notebook_version = '1.13'\n","Network = 'U-Net (3D)'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#Put the imported code and libraries here\n","# !pip install fpdf\n","from __future__ import absolute_import, division, print_function, unicode_literals\n","\n","try:\n"," import elasticdeform\n","except:\n"," !pip install elasticdeform\n"," import elasticdeform\n","\n","try:\n"," import tifffile\n","except:\n"," !pip install tifffile\n"," import tifffile\n","\n","try:\n"," import imgaug.augmenters as iaa\n","except:\n"," !pip install imgaug\n"," import imgaug.augmenters as iaa\n","\n","import os\n","import csv\n","import random\n","import h5py\n","import imageio\n","import math\n","import shutil\n","\n","import pandas as pd\n","from glob import glob\n","from tqdm import tqdm\n","\n","from skimage import transform\n","from skimage import exposure\n","from skimage import color\n","from skimage import io\n","\n","from scipy.ndimage import zoom\n","\n","import matplotlib.pyplot as plt\n","\n","import numpy as np\n","\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","print(tf.__version__)\n","\n","# from keras import backend as K\n","\n","# from keras.layers import Conv3D\n","# from keras.layers import BatchNormalization\n","# from keras.layers import ReLU\n","# from keras.layers import MaxPooling3D\n","# from keras.layers import Conv3DTranspose\n","# from keras.layers import Input\n","# from keras.layers import Concatenate\n","\n","# from keras.models import Model\n","\n","# from keras.utils import Sequence\n","# from keras.callbacks import ModelCheckpoint\n","# from keras.callbacks import CSVLogger\n","# from keras.callbacks import Callback\n","\n","from tensorflow.keras import backend as K\n","\n","from tensorflow.keras.layers import Conv3D\n","from tensorflow.keras.layers import BatchNormalization\n","from 
tensorflow.keras.layers import ReLU\n","from tensorflow.keras.layers import MaxPooling3D\n","from tensorflow.keras.layers import Conv3DTranspose\n","from tensorflow.keras.layers import Input\n","from tensorflow.keras.layers import Concatenate\n","\n","from tensorflow.keras.models import Model\n","\n","from tensorflow.keras.utils import Sequence\n","from tensorflow.keras.callbacks import ModelCheckpoint\n","from tensorflow.keras.callbacks import CSVLogger\n","from tensorflow.keras.callbacks import Callback\n","\n","from tensorflow.keras.metrics import RootMeanSquaredError\n","\n","from tensorflow.keras.optimizers import Adam, SGD, RMSprop\n","\n","from ipywidgets import interact\n","from ipywidgets import interactive\n","from ipywidgets import fixed\n","from ipywidgets import interact_manual \n","import ipywidgets as widgets\n","\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","import time\n","\n","from skimage import io\n","import matplotlib\n","\n","print(\"Dependencies installed and imported.\")\n","\n","# Define MultiPageTiffGenerator class\n","class MultiPageTiffGenerator(Sequence):\n","\n"," def __init__(self,\n"," source_path,\n"," target_path,\n"," batch_size=1,\n"," shape=(128,128,32,1),\n"," augment=False,\n"," augmentations=[],\n"," deform_augment=False,\n"," deform_augmentation_params=(5,3,4),\n"," val_split=0.2,\n"," is_val=False,\n"," random_crop=True,\n"," downscale=1,\n"," binary_target=False):\n","\n"," # If directory with various multi-page tiffiles is provided read as list\n"," if os.path.isfile(source_path):\n"," self.dir_flag = False\n"," self.source = tifffile.imread(source_path)\n"," if binary_target:\n"," self.target = tifffile.imread(target_path).astype(np.bool)\n"," else:\n"," self.target = tifffile.imread(target_path)\n","\n"," elif os.path.isdir(source_path):\n"," self.dir_flag = True\n"," self.source_dir_list = glob(os.path.join(source_path, '*'))\n"," self.target_dir_list = glob(os.path.join(target_path, '*'))\n","\n"," self.source_dir_list.sort()\n"," self.target_dir_list.sort()\n","\n"," self.shape = shape\n"," self.batch_size = batch_size\n"," self.augment = augment\n"," self.val_split = val_split\n"," self.is_val = is_val\n"," self.random_crop = random_crop\n"," self.downscale = downscale\n"," self.binary_target = binary_target\n"," self.deform_augment = deform_augment\n"," self.on_epoch_end()\n"," \n"," if self.augment:\n"," # pass list of augmentation functions \n"," self.seq = iaa.Sequential(augmentations, random_order=True) # apply augmenters in random order\n"," if self.deform_augment:\n"," self.deform_sigma, self.deform_points, self.deform_order = deform_augmentation_params\n","\n"," def __len__(self):\n"," # If various multi-page tiff files provided sum all images within each\n"," if self.augment:\n"," augment_factor = 4\n"," else:\n"," augment_factor = 1\n"," \n"," if self.dir_flag:\n"," num_of_imgs = 0\n"," for tiff_path in self.source_dir_list:\n"," num_of_imgs += tifffile.imread(tiff_path).shape[0]\n"," xy_shape = tifffile.imread(self.source_dir_list[0]).shape[1:]\n","\n"," if self.is_val:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = xy_shape[0] * xy_shape[1] * self.val_split * num_of_imgs\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n"," else:\n"," return math.floor(self.val_split * num_of_imgs / self.batch_size)\n"," else:\n"," if 
self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = xy_shape[0] * xy_shape[1] * (1 - self.val_split) * num_of_imgs\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n","\n"," else:\n"," return math.floor(augment_factor*(1 - self.val_split) * num_of_imgs/self.batch_size)\n"," else:\n"," if self.is_val:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = self.source.shape[0] * self.source.shape[1] * self.val_split * self.source.shape[2]\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n"," else:\n"," return math.floor((self.val_split * self.source.shape[0] / self.batch_size))\n"," else:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = self.source.shape[0] * self.source.shape[1] * (1 - self.val_split) * self.source.shape[2]\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n"," else:\n"," return math.floor(augment_factor * (1 - self.val_split) * self.source.shape[0] / self.batch_size)\n","\n"," def __getitem__(self, idx):\n"," source_batch = np.empty((self.batch_size,\n"," self.shape[0],\n"," self.shape[1],\n"," self.shape[2],\n"," self.shape[3]))\n"," target_batch = np.empty((self.batch_size,\n"," self.shape[0],\n"," self.shape[1],\n"," self.shape[2],\n"," self.shape[3]))\n","\n"," for batch in range(self.batch_size):\n"," # Modulo operator ensures IndexError is avoided\n"," stack_start = self.batch_list[(idx+batch*self.shape[2])%len(self.batch_list)]\n","\n"," if self.dir_flag:\n"," self.source = tifffile.imread(self.source_dir_list[stack_start[0]])\n"," if self.binary_target:\n"," self.target = tifffile.imread(self.target_dir_list[stack_start[0]]).astype(np.bool)\n"," else:\n"," self.target = tifffile.imread(self.target_dir_list[stack_start[0]])\n","\n"," src_list = []\n"," tgt_list = []\n"," for i in range(stack_start[1], stack_start[1]+self.shape[2]):\n"," src = self.source[i]\n"," src = transform.downscale_local_mean(src, (self.downscale, self.downscale))\n"," if not self.random_crop:\n"," src = transform.resize(src, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n"," src = self._min_max_scaling(src)\n"," src_list.append(src)\n","\n"," tgt = self.target[i]\n"," tgt = transform.downscale_local_mean(tgt, (self.downscale, self.downscale))\n"," if not self.random_crop:\n"," tgt = transform.resize(tgt, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n"," if not self.binary_target:\n"," tgt = self._min_max_scaling(tgt)\n"," tgt_list.append(tgt)\n","\n"," if self.random_crop:\n"," if src.shape[0] == self.shape[0]:\n"," x_rand = 0\n"," if src.shape[1] == self.shape[1]:\n"," y_rand = 0\n"," if src.shape[0] > self.shape[0]:\n"," x_rand = np.random.randint(src.shape[0] - self.shape[0])\n"," if src.shape[1] > self.shape[1]:\n"," y_rand = np.random.randint(src.shape[1] - self.shape[1])\n"," if src.shape[0] < self.shape[0] or src.shape[1] < self.shape[1]:\n"," raise ValueError('Patch shape larger than (downscaled) source shape')\n"," \n"," for i in range(self.shape[2]):\n"," if self.random_crop:\n"," src = src_list[i]\n"," tgt = tgt_list[i]\n"," src_crop = src[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n"," tgt_crop = tgt[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n"," else:\n"," src_crop = src_list[i]\n"," tgt_crop = 
tgt_list[i]\n","\n"," source_batch[batch,:,:,i,0] = src_crop\n"," target_batch[batch,:,:,i,0] = tgt_crop\n","\n"," if self.augment:\n"," # On-the-fly data augmentation\n"," source_batch, target_batch = self.augment_volume(source_batch, target_batch)\n","\n"," # Data augmentation by reversing stack\n"," if np.random.random() > 0.5:\n"," source_batch, target_batch = source_batch[::-1], target_batch[::-1]\n"," \n"," # Data augmentation by elastic deformation\n"," if np.random.random() > 0.5 and self.deform_augment:\n"," source_batch, target_batch = self.deform_volume(source_batch, target_batch)\n"," \n"," if not self.binary_target:\n"," target_batch = self._min_max_scaling(target_batch)\n"," \n"," return self._min_max_scaling(source_batch), target_batch\n"," \n"," else:\n"," return source_batch, target_batch\n","\n"," def on_epoch_end(self):\n"," # Validation split performed here\n"," self.batch_list = []\n"," # Create batch_list of all combinations of tifffile and stack position\n"," if self.dir_flag:\n"," for i in range(len(self.source_dir_list)):\n"," num_of_pages = tifffile.imread(self.source_dir_list[i]).shape[0]\n"," if self.is_val:\n"," start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n"," for j in range(start_page, num_of_pages-self.shape[2]):\n"," self.batch_list.append([i, j])\n"," else:\n"," last_page = math.floor((1-self.val_split)*num_of_pages)\n"," for j in range(last_page-self.shape[2]):\n"," self.batch_list.append([i, j])\n"," else:\n"," num_of_pages = self.source.shape[0]\n"," if self.is_val:\n"," start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n"," for j in range(start_page, num_of_pages-self.shape[2]):\n"," self.batch_list.append([0, j])\n","\n"," else:\n"," last_page = math.floor((1-self.val_split)*num_of_pages)\n"," for j in range(last_page-self.shape[2]):\n"," self.batch_list.append([0, j])\n"," \n"," if self.is_val and (len(self.batch_list) <= 0):\n"," raise ValueError('validation_split too small! 
Increase val_split or decrease z-depth')\n"," random.shuffle(self.batch_list)\n"," \n"," def _min_max_scaling(self, data):\n"," n = data - np.min(data)\n"," d = np.max(data) - np.min(data) \n"," \n"," return n/d\n"," \n"," def class_weights(self):\n"," ones = 0\n"," pixels = 0\n","\n"," if self.dir_flag:\n"," for i in range(len(self.target_dir_list)):\n"," tgt = tifffile.imread(self.target_dir_list[i]).astype(np.bool)\n"," ones += np.sum(tgt)\n"," pixels += tgt.shape[0]*tgt.shape[1]*tgt.shape[2]\n"," else:\n"," ones = np.sum(self.target)\n"," pixels = self.target.shape[0]*self.target.shape[1]*self.target.shape[2]\n"," p_ones = ones/pixels\n"," p_zeros = 1-p_ones\n","\n"," # Return swapped probability to increase weight of unlikely class\n"," return p_ones, p_zeros\n","\n"," def deform_volume(self, src_vol, tgt_vol):\n"," [src_dfrm, tgt_dfrm] = elasticdeform.deform_random_grid([src_vol, tgt_vol],\n"," axis=(1, 2, 3),\n"," sigma=self.deform_sigma,\n"," points=self.deform_points,\n"," order=self.deform_order)\n"," if self.binary_target:\n"," tgt_dfrm = tgt_dfrm > 0.1\n"," \n"," return self._min_max_scaling(src_dfrm), tgt_dfrm \n","\n"," def augment_volume(self, src_vol, tgt_vol):\n"," src_vol_aug = np.empty(src_vol.shape)\n"," tgt_vol_aug = np.empty(tgt_vol.shape)\n","\n"," for i in range(src_vol.shape[3]):\n"," src_vol_aug[:,:,:,i,0], tgt_vol_aug[:,:,:,i,0] = self.seq(images=src_vol[:,:,:,i,0].astype('float16'), \n"," segmentation_maps=tgt_vol[:,:,:,i,0].astype(bool))\n"," return self._min_max_scaling(src_vol_aug), tgt_vol_aug\n","\n"," def sample_augmentation(self, idx):\n"," src, tgt = self.__getitem__(idx)\n","\n"," src_aug, tgt_aug = self.augment_volume(src, tgt)\n"," \n"," if self.deform_augment:\n"," src_aug, tgt_aug = self.deform_volume(src_aug, tgt_aug)\n","\n"," return src_aug, tgt_aug \n","\n","# Define custom loss and dice coefficient\n","def dice_coefficient(y_true, y_pred):\n"," eps = 1e-6\n"," y_true_f = K.flatten(y_true)\n"," y_pred_f = K.flatten(y_pred)\n"," intersection = K.sum(y_true_f*y_pred_f)\n","\n"," return (2.*intersection)/(K.sum(y_true_f*y_true_f)+K.sum(y_pred_f*y_pred_f)+eps)\n","\n","def weighted_binary_crossentropy(zero_weight, one_weight):\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = K.binary_crossentropy(y_true, y_pred)\n","\n"," weight_vector = y_true*one_weight+(1.-y_true)*zero_weight\n"," weighted_binary_crossentropy = weight_vector*binary_crossentropy\n","\n"," return K.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","# Custom callback showing sample prediction\n","class SampleImageCallback(Callback):\n","\n"," def __init__(self, model, sample_data, model_path, save=False):\n"," self.model = model\n"," self.sample_data = sample_data\n"," self.model_path = model_path\n"," self.save = save\n","\n"," def on_epoch_end(self, epoch, logs={}):\n"," sample_predict = self.model.predict_on_batch(self.sample_data)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(self.sample_data[0,:,:,0,0], interpolation='nearest', cmap='gray')\n"," plt.title('Sample source')\n"," plt.axis('off');\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(sample_predict[0,:,:,0,0], interpolation='nearest', cmap='magma')\n"," plt.title('Predicted target')\n"," plt.axis('off');\n","\n"," plt.show()\n","\n"," if self.save:\n"," plt.savefig(self.model_path + '/epoch_' + str(epoch+1) + '.png')\n","\n","\n","# Define Unet3D class\n","class Unet3D:\n","\n"," def __init__(self,\n"," 
shape=(256,256,16,1)):\n"," if isinstance(shape, str):\n"," shape = eval(shape)\n","\n"," self.shape = shape\n"," \n"," input_tensor = Input(self.shape, name='input')\n","\n"," self.model = self.unet_3D(input_tensor)\n","\n"," def down_block_3D(self, input_tensor, filters):\n"," x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(input_tensor)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," return x\n","\n"," def up_block_3D(self, input_tensor, concat_layer, filters):\n"," x = Conv3DTranspose(filters, kernel_size=(2,2,2), strides=(2,2,2))(input_tensor)\n","\n"," x = Concatenate()([x, concat_layer])\n","\n"," x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(x)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," return x\n","\n"," def unet_3D(self, input_tensor, filters=32):\n"," d1 = self.down_block_3D(input_tensor, filters=filters)\n"," p1 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d1)\n"," d2 = self.down_block_3D(p1, filters=filters*2)\n"," p2 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d2)\n"," d3 = self.down_block_3D(p2, filters=filters*4)\n"," p3 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d3)\n","\n"," d4 = self.down_block_3D(p3, filters=filters*8)\n","\n"," u1 = self.up_block_3D(d4, d3, filters=filters*4)\n"," u2 = self.up_block_3D(u1, d2, filters=filters*2)\n"," u3 = self.up_block_3D(u2, d1, filters=filters)\n","\n"," output_tensor = Conv3D(filters=1, kernel_size=(1,1,1), activation='sigmoid')(u3)\n","\n"," return Model(inputs=[input_tensor], outputs=[output_tensor])\n","\n"," def summary(self):\n"," return self.model.summary()\n","\n"," # Pass generators instead\n"," def train(self, \n"," epochs, \n"," batch_size, \n"," train_generator,\n"," val_generator, \n"," model_path, \n"," model_name,\n"," optimizer='adam',\n"," learning_rate=0.001,\n"," loss='weighted_binary_crossentropy',\n"," metrics='dice',\n"," ckpt_period=1, \n"," save_best_ckpt_only=False, \n"," ckpt_path=None):\n","\n"," class_weight_zero, class_weight_one = train_generator.class_weights()\n"," \n"," if loss == 'weighted_binary_crossentropy':\n"," loss = weighted_binary_crossentropy(class_weight_zero, class_weight_one)\n"," \n"," if metrics == 'dice':\n"," metrics = dice_coefficient\n","\n"," if optimizer == 'adam':\n"," optimizer = Adam(learning_rate=learning_rate)\n"," elif optimizer == 'sgd':\n"," optimizer = SGD(learning_rate=learning_rate)\n"," elif optimizer == 'rmsprop':\n"," optimizer = RMSprop(learning_rate=learning_rate)\n","\n"," self.model.compile(optimizer=optimizer,\n"," loss=loss,\n"," metrics=[metrics])\n","\n"," if ckpt_path is not None:\n"," self.model.load_weights(ckpt_path)\n","\n"," full_model_path = os.path.join(model_path, model_name)\n","\n"," if not os.path.exists(full_model_path):\n"," os.makedirs(full_model_path)\n"," \n"," log_dir = full_model_path + '/Quality Control'\n","\n"," if not os.path.exists(log_dir):\n"," os.makedirs(log_dir)\n"," \n"," ckpt_dir = full_model_path + '/ckpt'\n","\n"," if not os.path.exists(ckpt_dir):\n"," os.makedirs(ckpt_dir)\n","\n"," csv_out_name = log_dir + '/training_evaluation.csv'\n"," if ckpt_path is None:\n"," csv_logger = 
CSVLogger(csv_out_name)\n"," else:\n"," csv_logger = CSVLogger(csv_out_name, append=True)\n","\n"," if save_best_ckpt_only:\n"," ckpt_name = ckpt_dir + '/' + model_name + '.hdf5'\n"," else:\n"," ckpt_name = ckpt_dir + '/' + model_name + '_epoch_{epoch:02d}_val_loss_{val_loss:.4f}.hdf5'\n"," \n"," model_ckpt = ModelCheckpoint(ckpt_name,\n"," verbose=1,\n"," period=ckpt_period,\n"," save_best_only=save_best_ckpt_only,\n"," save_weights_only=True)\n","\n"," sample_batch, __ = val_generator.__getitem__(random.randint(0, len(val_generator)))\n"," sample_img = SampleImageCallback(self.model, \n"," sample_batch, \n"," model_path)\n","\n"," self.model.fit_generator(generator=train_generator,\n"," validation_data=val_generator,\n"," validation_steps=math.floor(len(val_generator)/batch_size),\n"," epochs=epochs,\n"," callbacks=[csv_logger,\n"," model_ckpt,\n"," sample_img])\n","\n"," last_ckpt_name = ckpt_dir + '/' + model_name + '_last.hdf5'\n"," self.model.save_weights(last_ckpt_name)\n","\n"," def _min_max_scaling(self, data):\n"," n = data - np.min(data)\n"," d = np.max(data) - np.min(data) \n"," \n"," return n/d\n","\n"," def predict(self, \n"," input, \n"," ckpt_path, \n"," z_range=None, \n"," downscaling=None, \n"," true_patch_size=None):\n","\n"," self.model.load_weights(ckpt_path)\n","\n"," if isinstance(downscaling, str):\n"," downscaling = eval(downscaling)\n","\n"," if math.isnan(downscaling):\n"," downscaling = None\n","\n"," if isinstance(true_patch_size, str):\n"," true_patch_size = eval(true_patch_size)\n"," \n"," if not isinstance(true_patch_size, tuple): \n"," if math.isnan(true_patch_size):\n"," true_patch_size = None\n","\n"," if isinstance(input, str):\n"," src_volume = tifffile.imread(input)\n"," elif isinstance(input, np.ndarray):\n"," src_volume = input\n"," else:\n"," raise TypeError('Input is not path or numpy array!')\n"," \n"," in_size = src_volume.shape\n","\n"," if downscaling or true_patch_size is not None:\n"," x_scaling = 0\n"," y_scaling = 0\n","\n"," if true_patch_size is not None:\n"," x_scaling += true_patch_size[0]/self.shape[0]\n"," y_scaling += true_patch_size[1]/self.shape[1]\n"," if downscaling is not None:\n"," x_scaling += downscaling\n"," y_scaling += downscaling\n","\n"," src_list = []\n"," for i in range(src_volume.shape[0]):\n"," src_list.append(transform.downscale_local_mean(src_volume[i], (int(x_scaling), int(y_scaling))))\n"," src_volume = np.array(src_list) \n","\n"," if z_range is not None:\n"," src_volume = src_volume[z_range[0]:z_range[1]]\n","\n"," src_volume = self._min_max_scaling(src_volume) \n","\n"," src_array = np.zeros((1,\n"," math.ceil(src_volume.shape[1]/self.shape[0])*self.shape[0], \n"," math.ceil(src_volume.shape[2]/self.shape[1])*self.shape[1],\n"," math.ceil(src_volume.shape[0]/self.shape[2])*self.shape[2], \n"," self.shape[3]))\n","\n"," for i in range(src_volume.shape[0]):\n"," src_array[0,:src_volume.shape[1],:src_volume.shape[2],i,0] = src_volume[i]\n","\n"," pred_array = np.empty(src_array.shape)\n","\n"," for i in range(math.ceil(src_volume.shape[1]/self.shape[0])):\n"," for j in range(math.ceil(src_volume.shape[2]/self.shape[1])):\n"," for k in range(math.ceil(src_volume.shape[0]/self.shape[2])):\n"," pred_temp = self.model.predict(src_array[:,\n"," i*self.shape[0]:i*self.shape[0]+self.shape[0],\n"," j*self.shape[1]:j*self.shape[1]+self.shape[1],\n"," k*self.shape[2]:k*self.shape[2]+self.shape[2]])\n"," pred_array[:,\n"," i*self.shape[0]:i*self.shape[0]+self.shape[0],\n"," 
j*self.shape[1]:j*self.shape[1]+self.shape[1],\n"," k*self.shape[2]:k*self.shape[2]+self.shape[2]] = pred_temp\n"," \n"," pred_volume = np.rollaxis(np.squeeze(pred_array), -1)[:src_volume.shape[0],:src_volume.shape[1],:src_volume.shape[2]] \n","\n"," if downscaling is not None:\n"," pred_list = []\n"," for i in range(pred_volume.shape[0]):\n"," pred_list.append(transform.resize(pred_volume[i], (in_size[1], in_size[2]), preserve_range=True))\n"," pred_volume = np.array(pred_list)\n","\n"," return pred_volume\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," if os.path.isdir(training_source):\n"," shape = io.imread(training_source+'/'+os.listdir(training_source)[0]).shape\n"," elif os.path.isfile(training_source):\n"," shape = io.imread(training_source).shape\n"," else:\n"," print('Cannot read training data.')\n","\n"," dataset_size = len(train_generator)\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch_size: '+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if add_gaussian_blur == True:\n"," aug_text = aug_text+'\\n- gaussian blur'\n"," if add_linear_contrast == True:\n"," aug_text = aug_text+'\\n- linear contrast'\n"," if add_additive_gaussian_noise == True:\n"," aug_text = aug_text+'\\n- additive gaussian noise'\n"," if augmenters != '':\n"," aug_text = aug_text+'\\n- imgaug augmentations: '+augmenters\n"," if add_elastic_deform == True:\n"," aug_text = aug_text+'\\n- elastic deformation'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if use_default_advanced_parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
  <table width=40% style=\"margin-left:0px;\">\n","  <tr>\n","    <th width = 50% align=\"left\">Parameter</th>\n","    <th width = 50% align=\"left\">Value</th>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>number_of_epochs</td>\n","    <td width = 50%>{0}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>batch_size</td>\n","    <td width = 50%>{1}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>patch_size</td>\n","    <td width = 50%>{2}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>image_pre_processing</td>\n","    <td width = 50%>{3}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>validation_split_in_percent</td>\n","    <td width = 50%>{4}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>downscaling_in_xy</td>\n","    <td width = 50%>{5}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>binary_target</td>\n","    <td width = 50%>{6}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>loss_function</td>\n","    <td width = 50%>{7}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>metrics</td>\n","    <td width = 50%>{8}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>optimizer</td>\n","    <td width = 50%>{9}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>checkpointing_period</td>\n","    <td width = 50%>{10}</td>\n","  </tr>\n","  <tr>\n","    <td width = 50%>save_best_only</td>\n","    <td width = 50%>{11}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>resume_training</td>\n","    <td width = 50%>{12}</td>\n","  </tr>\n","  </table>
\n"," \"\"\".format(number_of_epochs,batch_size,str(patch_size[0])+'x'+str(patch_size[1])+'x'+str(patch_size[2]),image_pre_processing, validation_split_in_percent, downscaling_in_xy, str(binary_target), loss_function, metrics, optimizer, checkpointing_period, str(save_best_only), str(resume_training))\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Unet3D.png').shape\n"," pdf.image('/content/TrainingDataExample_Unet3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet 3D: Çiçek, Özgün, et al. \"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_4, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'U-Net 3D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+qc_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," if os.path.exists(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png'):\n"," exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png').shape\n"," pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png').shape\n"," pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'IoU threshold optimisation', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.ln(1)\n"," pdf.cell(120, 5, txt='Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh), align='L', ln=1)\n"," pdf.ln(2)\n"," exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png').shape\n"," pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png', x=16, y=None, w = round(exp_size[1]/6), h = round(exp_size[0]/6))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet 3D: Çiçek, Özgün, et al. 
\"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/'+qc_model_name+'_QC_report.pdf')\n","\n"," print('------------------------------')\n"," print('QC PDF report exported in '+os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/')\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net 3D and dependencies installed.')\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n"," \n","\n","# Check if this is the latest version of the notebook\n","# Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","# if Notebook_version == list(Latest_notebook_version.columns):\n","# print(\"This notebook is up-to-date.\")\n","\n","# if not Notebook_version == list(Latest_notebook_version.columns):\n","# print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Complete the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount Google Drive**\n","---\n"," To use this notebook with your **own data**, place it in a folder on **Google Drive** following one of the directory structures outlined in **Section 0**.\n","\n","1. **Run** the **cell** below to mount your Google Drive and follow the link. \n","\n","2. **Sign in** to your Google account and press 'Allow'. \n","\n","3. Next, copy the **authorization code**, paste it into the cell and press enter. This will allow Colab to read and write data from and to your Google Drive. \n","\n","4. Once this is done, your data can be viewed in the **Files tab** on the top left of the notebook after hitting 'Refresh'."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":["## **3.1. Choosing parameters**\n","\n","---\n","\n","### **Paths to training data and model**\n","\n","* **`training_source`** and **`training_target`** specify the paths to the training data. They can either be a single multipage TIFF file each or directories containing various multipage TIFF files in which case target and source files must be named identically within the respective directories. See Section 0 for a detailed description of the necessary directory structure.\n","\n","* **`model_name`** will be used when naming checkpoints. Adhere to a `lower_case_with_underscores` naming convention and beware of using the name of an existing model within the same folder, as it will be overwritten.\n","\n","* **`model_path`** specifies the directory where the model checkpoints and quality control logs will be saved.\n","\n","\n","**Note:** You can copy paths from the 'Files' tab by right-clicking any folder or file and selecting 'Copy path'. \n","\n","### **Training parameters**\n","\n","* **`number_of_epochs`** is the number of times the entire training data will be seen by the model. *Default: >100*\n","\n","* **`batch_size`** is the number of training patches of size `patch_size` that will be bundled together at each training step. 
*Default: 1*\n","\n","* **`patch_size`** specifies the size of the three-dimensional training patches in (x, y, z) that will be fed to the model. In order to avoid errors, preferably use a square aspect ratio or stick to the advanced parameters. *Default: <(512, 512, 16)*\n","\n","* **`validation_split_in_percent`** is the relative amount of training data that will be set aside for validation. *Default: 20* \n","\n","* **`downscaling_in_xy`** downscales the training images by the specified amount in x and y. This is useful to enforce isotropic pixel-size if the z resolution is lower than the xy resolution in the training volume or to capture a larger field-of-view while decreasing the memory requirements. *Default: 1*\n","\n","* **`image_pre_processing`** selects whether the training images are randomly cropped during training or resized to `patch_size`. Choose `randomly crop to patch_size` to shrink the field-of-view of the training images to the `patch_size`. *Default: resize to patch_size* \n","\n","* **`binary_target`** forces the target image to be binary. Choose this if your model is trained to perform binary segmentation tasks. *Default: True* \n","\n","* **`loss_function`** defines the loss. Read more [here](https://keras.io/api/losses/). *Default: weighted_binary_crossentropy* \n","\n","* **`metrics`** defines the metric. Read more [here](https://keras.io/api/metrics/). *Default: dice* \n","\n","* **`optimizer`** defines the optimizer. Read more [here](https://keras.io/api/optimizers/). *Default: adam* \n","\n","**Note:** If a *ResourceExhaustedError* is raised in Section 4.1. during training, decrease `batch_size` and `patch_size`. Decrease `batch_size` first and, if the error persists at `batch_size = 1`, reduce the `patch_size`. \n","\n","**Note:** The number of steps per epoch is calculated as `floor(augment_factor * (1 - validation_split) * num_of_slices / batch_size)` if `image_pre_processing` is `resize to patch_size`, where `augment_factor` is four if `apply_data_augmentation` is `True` and one otherwise. The `num_of_slices` is the overall number of slices (z-depth) in the training set across all provided image volumes. 
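As a purely illustrative calculation (the numbers are arbitrary and not taken from the example dataset), a single training volume of 200 slices with `validation_split_in_percent = 20`, `batch_size = 1` and no augmentation gives `floor(1 * 0.8 * 200 / 1) = 160` steps per epoch. 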
If `image_pre_processing` is `randomly crop to patch_size`, the number of steps per epoch are calculated as `floor(augment_factor * volume / (crop_volume * batch_size))` where `volume` is the overall volume of the training data in pixels accounting for the validation split and `crop_volume` is defined as the volume in pixels based on the specified `patch_size`."]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training data:\n","training_source = \"\" #@param {type:\"string\"}\n","training_target = \"\" #@param {type:\"string\"}\n","\n","#@markdown ---\n","\n","#@markdown ###Model name and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = os.path.join(model_path, model_name)\n","\n","#@markdown ---\n","\n","#@markdown ###Training parameters\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown ###Default advanced parameters\n","use_default_advanced_parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown If not, please change:\n","\n","batch_size = 1#@param {type:\"number\"}\n","patch_size = (256,256,4) #@param {type:\"number\"} # in pixels\n","training_shape = patch_size + (1,)\n","image_pre_processing = 'resize to patch_size' #@param [\"randomly crop to patch_size\", \"resize to patch_size\"]\n","\n","validation_split_in_percent = 20 #@param{type:\"number\"}\n","downscaling_in_xy = 2#@param {type:\"number\"} # in pixels\n","\n","binary_target = True #@param {type:\"boolean\"}\n","\n","loss_function = 'weighted_binary_crossentropy' #@param [\"weighted_binary_crossentropy\", \"binary_crossentropy\", \"categorical_crossentropy\", \"sparse_categorical_crossentropy\", \"mean_squared_error\", \"mean_absolute_error\"]\n","\n","metrics = 'dice' #@param [\"dice\", \"accuracy\"]\n","\n","optimizer = 'adam' #@param [\"adam\", \"sgd\", \"rmsprop\"]\n","\n","learning_rate = 0.0001 #@param{type:\"number\"}\n","\n","if image_pre_processing == \"randomly crop to patch_size\":\n"," random_crop = True\n","else:\n"," random_crop = False\n","\n","if use_default_advanced_parameters: \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," training_shape = (256,256,8,1)\n"," validation_split_in_percent = 20\n"," downscaling_in_xy = 1\n"," random_crop = True\n"," binary_target = True\n"," loss_function = 'weighted_binary_crossentropy'\n"," metrics = 'dice'\n"," optimizer = 'adam'\n"," learning_rate = 0.001 \n","#@markdown ###Checkpointing parameters\n","checkpointing_period = 1 #@param {type:\"number\"}\n","\n","#@markdown If chosen only the best checkpoint is saved, otherwise a checkpoint is saved every checkpoint_period epochs:\n","save_best_only = True #@param {type:\"boolean\"}\n","\n","#@markdown ###Resume training\n","#@markdown Choose if training was interrupted:\n","resume_training = False #@param {type:\"boolean\"}\n","\n","#@markdown ###Transfer learning\n","#@markdown For transfer learning, do not select resume_training and specify a checkpoint_path below:\n","checkpoint_path = \"\" #@param {type:\"string\"}\n","\n","if resume_training and checkpoint_path != \"\":\n"," print('If resume_training is True while checkpoint_path is specified, resume_training will be set to False!')\n"," resume_training = False\n"," \n","\n","# Retrieve last checkpoint\n","if resume_training:\n"," try:\n"," ckpt_dir_list = glob(full_model_path + '/ckpt/*')\n"," ckpt_dir_list.sort()\n"," last_ckpt_path = ckpt_dir_list[-1]\n"," 
print('Training will resume from checkpoint:', os.path.basename(last_ckpt_path))\n"," except IndexError:\n"," last_ckpt_path=None\n"," print('CheckpointError: No previous checkpoints were found, training from scratch.')\n","elif not resume_training and checkpoint_path != \"\":\n"," last_ckpt_path = checkpoint_path\n"," assert os.path.isfile(last_ckpt_path), 'checkpoint_path does not exist!'\n","else:\n"," last_ckpt_path=None\n","\n","# Instantiate Unet3D \n","model = Unet3D(shape=training_shape)\n","\n","#here we check that no model with the same name already exist\n","if not resume_training and os.path.exists(full_model_path): \n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n"," # print('!! WARNING: Folder already exists and will be overwritten !!') \n"," # shutil.rmtree(full_model_path)\n","\n","# if not os.path.exists(full_model_path):\n","# os.makedirs(full_model_path)\n","\n","# Show sample image\n","if os.path.isdir(training_source):\n"," training_source_sample = sorted(glob(os.path.join(training_source, '*')))[0]\n"," training_target_sample = sorted(glob(os.path.join(training_target, '*')))[0]\n","else:\n"," training_source_sample = training_source\n"," training_target_sample = training_target\n","\n","src_sample = tifffile.imread(training_source_sample)\n","src_sample = model._min_max_scaling(src_sample)\n","if binary_target:\n"," tgt_sample = tifffile.imread(training_target_sample).astype(np.bool)\n","else:\n"," tgt_sample = tifffile.imread(training_target_sample)\n","\n","src_down = transform.downscale_local_mean(src_sample[0], (downscaling_in_xy, downscaling_in_xy))\n","tgt_down = transform.downscale_local_mean(tgt_sample[0], (downscaling_in_xy, downscaling_in_xy)) \n","\n","if random_crop:\n"," true_patch_size = None\n","\n"," if src_down.shape[0] == training_shape[0]:\n"," x_rand = 0\n"," if src_down.shape[1] == training_shape[1]:\n"," y_rand = 0\n"," if src_down.shape[0] > training_shape[0]:\n"," x_rand = np.random.randint(src_down.shape[0] - training_shape[0])\n"," if src_down.shape[1] > training_shape[1]:\n"," y_rand = np.random.randint(src_down.shape[1] - training_shape[1])\n"," if src_down.shape[0] < training_shape[0] or src_down.shape[1] < training_shape[1]:\n"," raise ValueError('Patch shape larger than (downscaled) source shape')\n","else:\n"," true_patch_size = src_down.shape\n","\n","def scroll_in_z(z):\n"," src_down = transform.downscale_local_mean(src_sample[z-1], (downscaling_in_xy,downscaling_in_xy))\n"," tgt_down = transform.downscale_local_mean(tgt_sample[z-1], (downscaling_in_xy,downscaling_in_xy)) \n"," if random_crop:\n"," src_slice = src_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n"," tgt_slice = tgt_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n"," else:\n"," \n"," src_slice = transform.resize(src_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n"," tgt_slice = transform.resize(tgt_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(src_slice, cmap='gray')\n"," plt.title('Training source (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(tgt_slice, cmap='magma')\n"," plt.title('Training target (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n"," plt.savefig('/content/TrainingDataExample_Unet3D.png',bbox_inches='tight',pad_inches=0)\n"," 
#plt.close()\n","\n","print('This is what the training images will look like with the chosen settings')\n","interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_sample.shape[0], step=1, value=0));\n","plt.show()\n","#Create a copy of an example slice and close the display.\n","scroll_in_z(z=int(src_sample.shape[0]/2))\n","# If you close the display, then the users can't interactively inspect the data\n","# plt.close()\n","\n","# Save model parameters\n","params = {'training_source': training_source,\n"," 'training_target': training_target,\n"," 'model_name': model_name,\n"," 'model_path': model_path,\n"," 'number_of_epochs': number_of_epochs,\n"," 'batch_size': batch_size,\n"," 'training_shape': training_shape,\n"," 'downscaling': downscaling_in_xy,\n"," 'true_patch_size': true_patch_size,\n"," 'val_split': validation_split_in_percent/100,\n"," 'random_crop': random_crop}\n","\n","params_df = pd.DataFrame.from_dict(params, orient='index')\n","\n","# apply_data_augmentation = False\n","# pdf_export(augmentation = apply_data_augmentation, pretrained_model = resume_training)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["## **3.2. Data augmentation**\n"," \n","---\n"," Augmenting the training data increases robustness of the model by simulating possible variations within the training data, which prevents it from overfitting on small datasets. We therefore strongly recommend augmenting the data and making sure that the applied augmentations are reasonable.\n","\n","* **Gaussian blur** blurs images using Gaussian kernels with a sigma sampled uniformly from `(0, gaussian_sigma)`. This augmentation step is applied with a probability of `gaussian_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/blur.html#gaussianblur).\n","\n","* **Linear contrast** modifies the contrast of images according to `127 + alpha *(pixel_value-127)`, where `alpha` is sampled uniformly from the interval `[contrast_min, contrast_max]`. This augmentation step is applied with a probability of `contrast_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/contrast.html#linearcontrast).\n","\n","* **Additive Gaussian noise** adds Gaussian noise sampled once per pixel from a normal distribution `N(0, s)`, where `s` is sampled from `[scale_min, scale_max]`. This augmentation step is applied with a probability of `noise_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/arithmetic.html#additivegaussiannoise).\n","\n","* **Add custom augmenters** allows you to create a custom augmentation pipeline using the [augmenters available in the imgaug library](https://imgaug.readthedocs.io/en/latest/source/overview_of_augmenters.html).\n","For example, chaining `GammaContrast`, `AverageBlur` and `LinearContrast` with frequencies 0.3, 0.4 and 0.5 results in an augmentation pipeline equivalent to: \n","```\n","seq = iaa.Sequential([\n"," iaa.Sometimes(0.3, iaa.GammaContrast((0.5, 2.0))), \n"," iaa.Sometimes(0.4, iaa.AverageBlur((0.5, 2.0))), \n"," iaa.Sometimes(0.5, iaa.LinearContrast((0.4, 1.6)))\n","], random_order=True)\n","```\n"," Note that there is no limit on the number of augmenters that can be chained together and that individual augmenter and parameter entries must be separated by `;`. Custom augmenters do not overwrite the preset augmentation steps (*Gaussian blur*, *Linear contrast* or *Additive Gaussian noise*). 
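As a hypothetical illustration (these values are not pre-filled anywhere in the notebook), entering `GammaContrast; AverageBlur` in `augmenters`, `(0.5, 2.0); (0.5, 2.0)` in `augmenter_params` and `0.3; 0.4` in `augmenter_frequency` would reproduce the first two `iaa.Sometimes` steps of the pipeline shown above. 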
Also, the augmenters, augmenter parameters and augmenter frequencies must be entered such that each position within the string corresponds to the same augmentation step.\n","\n","* **`apply_data_augmentation`** ensures that data augmentation is randomly applied to the training data at each training step. This includes inverting the order of the slices within a training patch, as well as applying any augmenters that are added. *Default: True*\n","\n","* **`add_elastic_deform`** ensures that elastic grid-based deformations are applied as described in the original 3D U-Net paper. *Default: True*"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#@markdown ##**Augmentation options**\n","\n","#@markdown ###Data augmentation\n","\n","apply_data_augmentation = False #@param {type:\"boolean\"}\n","\n","# List of augmentations\n","augmentations = []\n","\n","#@markdown ###Gaussian blur\n","add_gaussian_blur = True #@param {type:\"boolean\"}\n","gaussian_sigma = 0.7#@param {type:\"number\"}\n","gaussian_frequency = 0.5 #@param {type:\"number\"}\n","\n","if add_gaussian_blur:\n"," augmentations.append(iaa.Sometimes(gaussian_frequency, iaa.GaussianBlur(sigma=(0, gaussian_sigma))))\n","\n","#@markdown ###Linear contrast\n","add_linear_contrast = True #@param {type:\"boolean\"}\n","contrast_min = 0.4 #@param {type:\"number\"}\n","contrast_max = 1.6#@param {type:\"number\"}\n","contrast_frequency = 0.5 #@param {type:\"number\"}\n","\n","if add_linear_contrast:\n"," augmentations.append(iaa.Sometimes(contrast_frequency, iaa.LinearContrast((contrast_min, contrast_max))))\n","\n","#@markdown ###Additive Gaussian noise\n","add_additive_gaussian_noise = False #@param {type:\"boolean\"}\n","scale_min = 0 #@param {type:\"number\"}\n","scale_max = 0.05 #@param {type:\"number\"}\n","noise_frequency = 0.5 #@param {type:\"number\"}\n","\n","if add_additive_gaussian_noise:\n"," augmentations.append(iaa.Sometimes(noise_frequency, iaa.AdditiveGaussianNoise(scale=(scale_min, scale_max))))\n","\n","#@markdown ###Add custom augmenters\n","add_custom_augmenters = False #@param {type:\"boolean\"} \n","augmenters = \"\" #@param {type:\"string\"}\n","\n","if add_custom_augmenters:\n","\n"," augmenter_params = \"\" #@param {type:\"string\"}\n","\n"," augmenter_frequency = \"\" #@param {type:\"string\"}\n","\n"," aug_lst = augmenters.split(';')\n"," aug_params_lst = augmenter_params.split(';')\n"," aug_freq_lst = augmenter_frequency.split(';')\n","\n"," assert len(aug_lst) == len(aug_params_lst) and len(aug_lst) == len(aug_freq_lst), 'The number of arguments in augmenters, augmenter_params and augmenter_frequency are not the same!'\n","\n"," for __, (aug, param, freq) in enumerate(zip(aug_lst, aug_params_lst, aug_freq_lst)):\n"," aug, param, freq = aug.strip(), param.strip(), freq.strip() \n"," aug_func = iaa.Sometimes(eval(freq), getattr(iaa, aug)(eval(param)))\n"," augmentations.append(aug_func)\n","\n","#@markdown ###Elastic deformations\n","add_elastic_deform = True #@param {type:\"boolean\"}\n","sigma = 2#@param {type:\"number\"}\n","points = 2#@param {type:\"number\"}\n","order = 2#@param {type:\"number\"}\n","\n","if add_elastic_deform:\n"," deform_params = (sigma, points, order)\n","else:\n"," deform_params = None\n","\n","train_generator = MultiPageTiffGenerator(training_source,\n"," training_target,\n"," batch_size=batch_size,\n"," shape=training_shape,\n"," augment=apply_data_augmentation,\n"," augmentations=augmentations,\n"," deform_augment=add_elastic_deform,\n"," 
deform_augmentation_params=deform_params,\n"," val_split=validation_split_in_percent/100,\n"," random_crop=random_crop,\n"," downscale=downscaling_in_xy,\n"," binary_target=binary_target)\n","\n","val_generator = MultiPageTiffGenerator(training_source,\n"," training_target,\n"," batch_size=batch_size,\n"," shape=training_shape,\n"," val_split=validation_split_in_percent/100,\n"," is_val=True,\n"," random_crop=random_crop,\n"," downscale=downscaling_in_xy,\n"," binary_target=binary_target)\n","\n","\n","if apply_data_augmentation:\n"," print('Data augmentation enabled.')\n"," sample_src_aug, sample_tgt_aug = train_generator.sample_augmentation(random.randint(0, len(train_generator)))\n","\n"," def scroll_in_z(z):\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(sample_src_aug[0,:,:,z-1,0], cmap='gray')\n"," plt.title('Sample augmented source (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(sample_tgt_aug[0,:,:,z-1,0], cmap='magma')\n"," plt.title('Sample training target (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," print('This is what the augmented training images will look like with the chosen settings')\n"," interact(scroll_in_z, z=widgets.IntSlider(min=1, max=sample_src_aug.shape[3], step=1, value=0));\n","\n","else:\n"," print('Data augmentation disabled.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---\n","\n","**CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training times must be less than 12 hours! If training takes longer than 12 hours, please decrease `number_of_epochs`."]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Show model and start training**\n","---\n"]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ## Show model summary\n","model.summary()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"CyQI4ssarUp4","cellView":"form"},"source":["#@markdown ##Start training\n","\n","#here we check that no model with the same name already exist, if so delete\n","if not resume_training and os.path.exists(full_model_path): \n"," shutil.rmtree(full_model_path)\n"," print(bcolors.WARNING+'!! 
WARNING: Folder already exists and has been overwritten !!'+bcolors.NORMAL) \n","\n","if not os.path.exists(full_model_path):\n"," os.makedirs(full_model_path)\n","\n","pdf_export(augmentation = apply_data_augmentation, pretrained_model = resume_training)\n","\n","# Save file\n","params_df.to_csv(os.path.join(full_model_path, 'params.csv'))\n","\n","start = time.time()\n","# Start Training\n","model.train(epochs=number_of_epochs,\n"," batch_size=batch_size,\n"," train_generator=train_generator,\n"," val_generator=val_generator,\n"," model_path=model_path,\n"," model_name=model_name,\n"," loss=loss_function,\n"," metrics=metrics,\n"," optimizer=optimizer,\n"," learning_rate=learning_rate,\n"," ckpt_period=checkpointing_period,\n"," save_best_ckpt_only=save_best_only,\n"," ckpt_path=last_ckpt_path)\n","\n","print('Training successfully completed!')\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = apply_data_augmentation, pretrained_model = resume_training)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Download your model from Google Drive**\n","\n","---\n","Once training is complete, the trained model is automatically saved to your Google Drive, in the **`model_path`** folder that was specified in Section 3. Download the folder to avoid any unwanted surprises, since the data can be erased if you train another model using the same `model_path`."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Download model directory\n","#@markdown 1. Specify the model path in `model_path_download`, otherwise the model specified in Section 3.1 will be downloaded\n","#@markdown 2. Run this cell to zip the model directory\n","#@markdown 3. Download the zipped file from the *Files* tab on the left\n","\n","from google.colab import files\n","\n","model_path_download = \"\" #@param {type:\"string\"}\n","\n","if len(model_path_download) == 0:\n"," model_path_download = full_model_path\n","\n","model_name_download = os.path.basename(model_path_download)\n","\n","print('Zipping', model_name_download)\n","\n","zip_model_path = model_name_download + '.zip'\n","\n","!zip -r \"$zip_model_path\" \"$model_path_download\"\n","\n","print('Successfully saved zipped model directory as', zip_model_path)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","In this section, the newly trained model can be assessed for performance. This involves inspecting the loss function in Section 5.1 
and employing more advanced metrics in Section 5.2.\n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["#@markdown ###Model to be evaluated:\n","#@markdown If left blank, the latest model defined in Section 3 will be evaluated:\n","\n","qc_model_name = \"\" #@param {type:\"string\"}\n","qc_model_path = \"\" #@param {type:\"string\"}\n","\n","if len(qc_model_path) == 0 and len(qc_model_name) == 0:\n"," qc_model_name = model_name\n"," qc_model_path = model_path\n","\n","full_qc_model_path = os.path.join(qc_model_path, qc_model_name)\n","\n","if os.path.exists(full_qc_model_path):\n"," print(qc_model_name + ' will be evaluated')\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspecting loss function**\n","---\n","\n","**The training loss** is the error between prediction and target after each epoch calculated across the training data while the **validation loss** calculates the error on the (unseen) validation data. During training these values should decrease until converging at which point the model has been sufficiently trained. If the validation loss starts increasing while the training loss has plateaued, the model has overfit on the training data which reduces its ability to generalise. Aim to halt training before this point.\n","\n","**Note:** For a more in-depth explanation please refer to [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols et al.\n","\n","\n","The accuracy is another performance metric that is calculated after each epoch. We use the [Sørensen–Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) to score the prediction accuracy. 
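For a predicted mask P and a target mask T, the Dice score is `2|P ∩ T| / (|P| + |T|)`, i.e. twice the overlap divided by the summed sizes of the two masks, ranging from 0 (no overlap) to 1 (perfect agreement).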
\n","\n"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Visualise loss and accuracy\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","accuracyDataFromCSV = []\n","valaccuracyDataFromCSV = []\n","\n","with open(full_qc_model_path + '/Quality Control/training_evaluation.csv', 'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[2]))\n"," vallossDataFromCSV.append(float(row[4]))\n"," accuracyDataFromCSV.append(float(row[1]))\n"," valaccuracyDataFromCSV.append(float(row[3]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training and validation loss', fontsize=14)\n","plt.ylabel('Loss', fontsize=12)\n","plt.xlabel('Epochs', fontsize=12)\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.plot(epochNumber,accuracyDataFromCSV, label='Training accuracy')\n","plt.plot(epochNumber,valaccuracyDataFromCSV, label='Validation accuracy')\n","plt.title('Training and validation accuracy', fontsize=14)\n","plt.ylabel('Dice', fontsize=12)\n","plt.xlabel('Epochs', fontsize=12)\n","plt.legend()\n","plt.savefig(full_qc_model_path + '/Quality Control/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will provide both a visual indication of the model performance by comparing the overlay of the predicted and source volume."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Compare prediction and ground-truth on testing data\n","\n","#@markdown Provide an unseen annotated dataset to determine the performance of the model:\n","\n","testing_source = \"\" #@param{type:\"string\"}\n","testing_target = \"\" #@param{type:\"string\"}\n","\n","qc_dir = full_qc_model_path + '/Quality Control'\n","predict_dir = qc_dir + '/Prediction'\n","if os.path.exists(predict_dir):\n"," shutil.rmtree(predict_dir)\n","\n","os.makedirs(predict_dir)\n","\n","# predict_dir + '/' + \n","predict_path = os.path.splitext(os.path.basename(testing_source))[0] + '_prediction.tif'\n","\n","def last_chars(x):\n"," return(x[-11:])\n","\n","try:\n"," ckpt_dir_list = glob(full_qc_model_path + '/ckpt/*')\n"," ckpt_dir_list.sort(key=last_chars)\n"," last_ckpt_path = ckpt_dir_list[0]\n"," print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n","except IndexError:\n"," raise CheckpointError('No previous checkpoints were found, please retrain model.')\n","\n","# Load parameters\n","params = pd.read_csv(os.path.join(full_qc_model_path, 'params.csv'), names=['val'], header=0, index_col=0) \n","\n","model = Unet3D(shape=params.loc['training_shape', 'val'])\n","\n","prediction = model.predict(testing_source, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n","\n","tifffile.imwrite(predict_path, prediction.astype('float32'), imagej=True)\n","\n","print('Predicted images!')\n","\n","qc_metrics_path = full_qc_model_path + '/Quality Control/QC_metrics_' + qc_model_name + '.csv'\n","\n","test_target = tifffile.imread(testing_target)\n","test_source = 
tifffile.imread(testing_source)\n","test_prediction = tifffile.imread(predict_path)\n","\n","def scroll_in_z(z):\n","\n"," plt.figure(figsize=(25,5))\n"," # Source\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(test_source[z-1], cmap='gray')\n"," plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n","\n"," # Target (Ground-truth)\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(test_target[z-1], cmap='magma')\n"," plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n","\n"," # Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(test_prediction[z-1], cmap='magma')\n"," plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n"," \n"," # Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_target[z-1], cmap='Greens')\n"," plt.imshow(test_prediction[z-1], alpha=0.5, cmap='Purples')\n"," plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n"," plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n","interact(scroll_in_z, z=widgets.IntSlider(min=1, max=test_source.shape[0], step=1, value=0));"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aIvRxpZlsFeZ"},"source":["## **5.3. Determine best Intersection over Union and threshold**\n","---\n","\n","**Note:** This section is only relevant if the target image is a binary mask and `binary_target` is selected in Section 3! \n","\n","This section will provide both a visual and a quantitative indication of the model performance by comparing the overlay of the predicted and source volume, as well as computing the highest [**Intersection over Union**](https://en.wikipedia.org/wiki/Jaccard_index) (IoU) score. The IoU is also known as the Jaccard Index. \n","\n","The best threshold is calculated using the IoU. Each threshold value from 0 to 255 is tested and the threshold with the highest score is deemed the best. 
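For a thresholded prediction P and a target T, the IoU is computed as `|P ∩ T| / |P ∪ T|`, so a score of 1 indicates perfect overlap and 0 indicates no overlap. 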
The IoU is calculated for the entire volume in 3D."]},{"cell_type":"code","metadata":{"id":"XhkeZTFusHA8","cellView":"form"},"source":["\n","#@markdown ##Calculate Intersection over Union and best threshold \n","prediction = tifffile.imread(predict_path)\n","prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n","\n","target = tifffile.imread(testing_target).astype(np.bool)\n","\n","def iou_vs_threshold(prediction, target):\n"," threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," mask = prediction > threshold\n","\n"," intersection = np.logical_and(target, mask)\n"," union = np.logical_or(target, mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return threshold_list, IoU_scores_list\n","\n","threshold_list, IoU_scores_list = iou_vs_threshold(prediction, target)\n","thresh_arr = np.array(list(zip(threshold_list, IoU_scores_list)))\n","best_thresh = int(np.where(thresh_arr == np.max(thresh_arr[:,1]))[0])\n","best_iou = IoU_scores_list[best_thresh]\n","\n","print('Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh))\n","\n","def adjust_threshold(threshold, z):\n","\n"," f=plt.figure(figsize=(25,5))\n"," plt.subplot(1,4,1)\n"," plt.imshow((prediction[z-1] > threshold).astype('uint8'), cmap='magma')\n"," plt.title('Prediction (Threshold = ' + str(threshold) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,4,2)\n"," plt.imshow(target[z-1], cmap='magma')\n"," plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(test_source[z-1], cmap='gray')\n"," plt.imshow((prediction[z-1] > threshold).astype('uint8'), alpha=0.4, cmap='Reds')\n"," plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n","\n"," plt.subplot(1,4,4)\n"," plt.title('Threshold vs. IoU', fontsize=15)\n"," plt.plot(threshold_list, IoU_scores_list)\n"," plt.plot(threshold, IoU_scores_list[threshold], 'ro') \n"," plt.ylabel('IoU score')\n"," plt.xlabel('Threshold')\n"," plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png',bbox_inches=matplotlib.transforms.Bbox([[17.5,0],[23,5]]),pad_inches=0)\n"," plt.show()\n","\n","interact(adjust_threshold, \n"," threshold=widgets.IntSlider(min=0, max=255, step=1, value=best_thresh),\n"," z=widgets.IntSlider(min=1, max=prediction.shape[0], step=1, value=0));\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"S1Ks6CxfekhY"},"source":["## **5.4. Determine best Intersection over Union and threshold**\n","---\n","\n","**Note:** This section is only relevant if the target image is a binary mask and `binary_target` is selected in Section 3! \n","\n","This section will provide both a visual and a quantitative indication of the model performance by comparing the overlay of the predicted and source volume, as well as computing the highest [**Intersection over Union**](https://en.wikipedia.org/wiki/Jaccard_index) (IoU) score. The IoU is also known as the Jaccard Index. \n","\n","The best threshold is calculated using the IoU. Each threshold value from 0 to 255 is tested and the threshold with the highest score is deemed the best. 
The IoU is calculated for the entire volume in 3D."]},{"cell_type":"code","metadata":{"cellView":"form","id":"1XCLZp2Xet3x"},"source":["# ####\n","from pydeepimagej.yaml import BioImageModelZooConfig\n","import urllib\n","import warnings\n","warnings. filterwarnings(\"ignore\") \n","\n","# ------------- User input ------------\n","# information about the model\n","#@markdown ##Introduce the metadata of the model architecture:\n","Trained_model_name = \"\" #@param {type:\"string\"}\n","Trained_model_authors = \"[Author 1, Author 2, Author 3]\" #@param {type:\"string\"}\n","\n","Trained_model_description = \"\"#@param {type:\"string\"}\n","Trained_model_license = 'MIT'#@param {type:\"string\"}\n","Trained_model_references = [\"Çiçek, Özgün, et al. MICCAI 2016\", \"Lucas von Chamier et al. biorXiv 2020\"]\n","Trained_model_DOI = [\"https://doi.org/10.1007/978-3-319-46723-8_49\", \"https://doi.org/10.1101/2020.03.20.000133\"]\n","\n","# Add example image information\n","# ---------------------------------------\n","#@markdown ##Choose a threshold for DeepImageJ's postprocessing macro:\n","Use_The_Best_Average_Threshold = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","threshold = 155 #@param {type:\"number\"}\n","if Use_The_Best_Average_Threshold:\n"," threshold = best_thresh\n","\n","#@markdown ##Introduce the voxel size (pixel size for each Z-slice and the distance between Z-salices) (in microns) of the image provided as an example of the model processing:\n","# information about the example image\n","PixelSize = 1 #@param {type:\"number\"}\n","Zdistance = 1 #@param {type:\"number\"}\n","#@markdown ##Do you want to choose the exampleimage?\n","default_example_image = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","fileID = \"\" #@param {type:\"string\"}\n","if default_example_image:\n"," fileID = testing_source\n"," \n","example_image = tifffile.imread(fileID) \n","# Z-dim first\n","z_size = example_image.shape[0]\n","z_size = np.int(z_size/2)\n","\n","example_image = example_image[z_size-10:z_size+10]\n","path_example_im = \"/content/example_image_biomodelzoo.tif\"\n","tifffile.imsave(path_example_im, example_image)\n","\n","\n","\n","# Load model parameters\n","# ---------------------------------------\n","def last_chars(x):\n"," return(x[-11:])\n","try:\n"," ckpt_dir_list = glob(full_qc_model_path + '/ckpt/*')\n"," ckpt_dir_list.sort(key=last_chars)\n"," last_ckpt_path = ckpt_dir_list[0]\n"," print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n","except IndexError:\n"," raise CheckpointError('No previous checkpoints were found, please retrain model.')\n","params = pd.read_csv(os.path.join(full_qc_model_path, 'params.csv'), names=['val'], header=0, index_col=0) \n","\n","\n","\n","# Load the model and process the example image\n","# ---------------------------------------\n","model = Unet3D(shape=params.loc['training_shape', 'val'])\n","# prediction = model.predict(fileID, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n","prediction = model.predict(path_example_im, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n","\n","prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n","mask = prediction > threshold\n","# Z-dim first\n","# mask = mask[:20]\n","\n","\n","\n","\n","# ------------- Execute bioimage model zoo configuration ------------\n","# Check 
minimum size: it is [8,8] for the 2D XY plane\n","keras_model = model.model\n","# pooling_steps = 0\n","# for keras_layer in keras_model.layers:\n","# if keras_layer.name.startswith('max') or \"pool\" in keras_layer.name:\n","# pooling_steps += 1\n","# MinimumSize = [2**(pooling_steps), 2**(pooling_steps)]\n","MinimumSize = keras_model.input_shape[1:-1]\n","dij_config = BioImageModelZooConfig(keras_model, MinimumSize)\n","\n","# Model developer details\n","dij_config.Authors = Trained_model_authors[1:-1].split(',')\n","dij_config.Description = Trained_model_description\n","dij_config.Name = Trained_model_name\n","dij_config.References = Trained_model_references\n","dij_config.DOI = Trained_model_DOI\n","dij_config.License = Trained_model_license\n","\n","# Additional information about the model\n","dij_config.GitHub = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic'\n","dij_config.Date = datetime.now()\n","dij_config.Documentation = '/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki'\n","dij_config.Tags = ['ZeroCostDL4Mic', 'deepimagej', 'segmentation', '3DUNet']\n","dij_config.Framework = 'tensorflow'\n","\n","# Add the information about the test image. Note here PixelSize should be given in microns\n","dij_config.add_test_info(example_image, mask, [PixelSize, PixelSize, Zdistance])\n","dij_config.create_covers([example_image, mask])\n","dij_config.Covers = ['./input.png', './output.png']\n","\n","# Store the model weights\n","# ---------------------------------------\n","# used_bioimage_model_for_training_URL = \"/Some/URI/\"\n","# dij_config.Parent = used_bioimage_model_for_training_URL\n","\n","# Add weights information\n","format_authors = [\"pydeepimagej\"]\n","dij_config.add_weights_formats(keras_model, 'TensorFlow', \n"," parent=\"keras_hdf5\",\n"," authors=[a for a in format_authors])\n","dij_config.add_weights_formats(keras_model, 'KerasHDF5', \n"," authors=[a for a in format_authors])\n","\n","## Preprocessing and postprocessing\n","# -------------------------------------------------\n","## Prepare preprocessing file_ min_max_scaling\n","path_preprocessing = \"per_sample_scale_range.ijm\"\n","urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/bioimage.io/per_sample_scale_range.ijm\", path_preprocessing )\n","\n","# Modify the threshold in the macro to the chosen threshold\n","ijmacro = open(path_preprocessing,\"r\") \n","list_of_lines = ijmacro. readlines()\n","# Line 21 is the one corresponding to the optimal threshold\n","list_of_lines[24] = \"min_percentile = 0;\\n\"\n","list_of_lines[25] = \"max_percentile = 100;\\n\"\n","ijmacro.close()\n","ijmacro = open(path_preprocessing,\"w\") \n","ijmacro. writelines(list_of_lines)\n","ijmacro. close()\n","\n","## Prepare postprocessing file\n","path_postprocessing = \"binarize.ijm\"\n","urllib.request.urlretrieve(\"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/bioimage.io/binarize.ijm\", path_postprocessing )\n","\n","# Modify the threshold in the macro to the chosen threshold\n","ijmacro = open(path_postprocessing,\"r\") \n","list_of_lines = ijmacro. readlines()\n","# Line 21 is the one corresponding to the optimal threshold\n","list_of_lines[11] = \"optimalThreshold = {};\\n\".format(threshold/255) # The output in DeepImageJ will not be converted to the range [0,255], so the threshold is adjusted.\n","ijmacro.close()\n","ijmacro = open(path_postprocessing,\"w\") \n","ijmacro. writelines(list_of_lines)\n","ijmacro. 
close()\n","\n","\n","\n","# Include the info about the macros \n","dij_config.Preprocessing = [path_preprocessing]\n","dij_config.Preprocessing_files = [path_preprocessing]\n","dij_config.Postprocessing = [path_postprocessing]\n","dij_config.Postprocessing_files = [path_postprocessing]\n","dij_config.add_bioimageio_spec('pre-processing', 'percentile', mode='per_sample', axes='xyzc', min_percentile=0, max_percentile=100)\n","dij_config.add_bioimageio_spec('post-processing', 'binarize', threshold=threshold)\n","\n","## EXPORT THE MODEL TO AN EXISTING PATH OR CREATE IT\n","deepimagej_model_path = os.path.join(full_qc_model_path,qc_model_name+'.bioimage.io.model')\n","# if not os.path.exists(deepimagej_model_path):\n","# os.mkdir(deepimagej_model_path)\n","dij_config.export_model(deepimagej_model_path)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","Once sufficient performance of the trained model has been established using Section 5, the network can be used to segment unseen volumetric data."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate predictions from unseen dataset**\n","---\n","\n","The trained model can now be used to predict segmentation masks on unseen images. By default, the most recently trained model is used; to use an older model instead, provide its path in `model_path`. Predicted output images are saved in `output_path` as ImageJ-compatible TIFF files.\n","\n","## **Prediction parameters**\n","\n","* **`source_path`** specifies the location of the source image volume.\n","\n","* **`output_directory`** specifies the directory where the output predictions are stored.\n","\n","* **`binary_target`** should be chosen if the network is trained to predict binary segmentation masks.\n","\n","* **`threshold`** can be calculated in Section 5 and is used to generate binary masks from the predictions.\n","\n","* **`big_tiff`** should be chosen if the expected prediction exceeds 4GB. The predictions will be saved using the BigTIFF format. Beware that this might substantially reduce the prediction speed. *Default: False* \n","\n","* **`prediction_depth`** is only relevant if the prediction is saved as a BigTIFF. The prediction is not performed in one go, to avoid exhausting the available memory. Instead, the prediction is performed iteratively on subsets of the entire volume with shape `(prediction_depth, source.shape[1], source.shape[2])`. 
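For example, a source volume with 300 Z-slices and the default `prediction_depth` of 32 would be processed in 10 chunks of at most 32 slices each. 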
*Default: 32*\n","\n","* **`model_path`** specifies the path to a model other than the most recently trained."]},{"cell_type":"code","metadata":{"cellView":"form","id":"DEmhPh5fsWX2"},"source":["#@markdown ## Download example volume\n","\n","#@markdown This can take up to an hour\n","\n","import requests \n","import os\n","from tqdm.notebook import tqdm \n","\n","\n","def download_from_url(url, save_as):\n"," file_url = url\n"," r = requests.get(file_url, stream=True) \n"," \n"," with open(save_as, 'wb') as file: \n"," for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=3275073, ncols=1000):\n"," if block:\n"," file.write(block) \n","\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/volumedata.tif', 'example_dataset/volumedata.tif')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then run the cell to predict outputs from your unseen images.\n","\n","source_path = \"\" #@param {type:\"string\"}\n","output_directory = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(output_directory):\n"," os.makedirs(output_directory)\n","\n","output_path = os.path.join(output_directory, os.path.splitext(os.path.basename(source_path))[0] + '_predicted.tif')\n","#@markdown ###Prediction parameters:\n","\n","binary_target = True #@param {type:\"boolean\"}\n","\n","save_probability_map = False #@param {type:\"boolean\"}\n","\n","#@markdown Determine best threshold in Section 5.2.\n","\n","use_calculated_threshold = True #@param {type:\"boolean\"}\n","threshold = 200#@param {type:\"number\"}\n","\n","# Tifffile library issues means that images cannot be appended to \n","#@markdown Choose if prediction file exceeds 4GB or if input file is very large (above 2GB). Image volume saved as BigTIFF.\n","big_tiff = False #@param {type:\"boolean\"}\n","\n","#@markdown Reduce `prediction_depth` if runtime runs out of memory during prediction. 
Only relevant if prediction saved as BigTIFF\n","\n","prediction_depth = 32#@param {type:\"number\"}\n","\n","#@markdown ###Model to be evaluated\n","#@markdown If left blank, the latest model defined in Section 5 will be evaluated\n","\n","full_model_path_ = \"\" #@param {type:\"string\"}\n","\n","if len(full_model_path_) == 0:\n"," full_model_path_ = os.path.join(qc_model_path, qc_model_name) \n","\n","\n","\n","# Load parameters\n","params = pd.read_csv(os.path.join(full_model_path_, 'params.csv'), names=['val'], header=0, index_col=0) \n","model = Unet3D(shape=params.loc['training_shape', 'val'])\n","\n","if use_calculated_threshold:\n"," threshold = best_thresh\n","\n","def last_chars(x):\n"," return(x[-11:])\n","\n","try:\n"," ckpt_dir_list = glob(full_model_path_ + '/ckpt/*')\n"," ckpt_dir_list.sort(key=last_chars)\n"," last_ckpt_path = ckpt_dir_list[0]\n"," print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n","except IndexError:\n"," raise CheckpointError('No previous checkpoints were found, please retrain model.')\n","\n","src = tifffile.imread(source_path)\n","\n","if src.nbytes >= 4e9:\n"," big_tiff = True\n"," print('The source file exceeds 4GB in memory, prediction will be saved as BigTIFF!')\n","\n","if binary_target:\n"," if not big_tiff:\n"," prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," prediction = (prediction > threshold).astype('float32')\n","\n"," tifffile.imwrite(output_path, prediction, imagej=True)\n","\n"," else:\n"," with tifffile.TiffWriter(output_path, bigtiff=True) as tif:\n"," for i in tqdm(range(0, src.shape[0], prediction_depth)):\n"," prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," prediction = (prediction > threshold).astype('float32')\n"," \n"," for j in range(prediction.shape[0]):\n"," tif.save(prediction[j])\n","\n","if not binary_target or save_probability_map:\n"," if not binary_target:\n"," prob_map_path = output_path\n"," else:\n"," prob_map_path = os.path.splitext(output_path)[0] + '_prob_map.tif'\n"," \n"," if not big_tiff:\n"," prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," tifffile.imwrite(prob_map_path, prediction.astype('float32'), imagej=True)\n","\n"," else:\n"," with tifffile.TiffWriter(prob_map_path, bigtiff=True) as tif:\n"," for i in tqdm(range(0, src.shape[0], prediction_depth)):\n"," prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," \n"," for j in range(prediction.shape[0]):\n"," tif.save(prediction[j])\n","\n","print('Predictions saved as', output_path)\n","\n","src_volume = tifffile.imread(source_path)\n","pred_volume = tifffile.imread(output_path)\n","\n","def scroll_in_z(z):\n"," \n"," f=plt.figure(figsize=(25,5))\n"," plt.subplot(1,2,1)\n"," plt.imshow(src_volume[z-1], cmap='gray')\n"," 
plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(pred_volume[z-1], cmap='magma')\n"," plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n","interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_volume.shape[0], step=1, value=0));\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"q3lSeWp3G8eD"},"source":["# **7. Version log**\n","\n","---\n","**v1.13**: \n","* The section 1 and 2 are now swapped for better export of *requirements.txt*. \n","* This version also now includes built-in version check and the version log that you're reading now.\n","* Keras libraries are now imported via TensorFlow.\n","* The learning rate can be changed in section 3.1.\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using 3D U-Net!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/ZeroCostDL4Mic_Interactive_annotations_Cellpose.ipynb b/Colab_notebooks/Beta notebooks/ZeroCostDL4Mic_Interactive_annotations_Cellpose.ipynb index 93000f9c..85591ac6 100644 --- a/Colab_notebooks/Beta notebooks/ZeroCostDL4Mic_Interactive_annotations_Cellpose.ipynb +++ b/Colab_notebooks/Beta notebooks/ZeroCostDL4Mic_Interactive_annotations_Cellpose.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"ZeroCostDL4Mic_Interactive_annotations_Cellpose.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"ImJoy Interactive ML","language":"python","name":"imjoy-interactive-ml"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.9"}},"cells":[{"cell_type":"markdown","metadata":{"id":"8KzUtudGvysh"},"source":["# **Interactive segmentation**\n","---\n","\n","**Interactive segmentation** is a segmentation tool powered by deep learning and ImJoy that can be used to segment bioimages and was first published by [Ouyang *et al.* in 2021, on F1000R](https://f1000research.com/articles/10-142?s=09#ref-15).\n","\n","**The Original code** is freely available in GitHub:\n","/~https://github.com/imjoy-team/imjoy-interactive-segmentation\n","\n","**Please also cite this original paper when using or developing this notebook.**\n","\n","**!!Currently, this notebook only works with Google Chrome or Firefox!!**\n","\n"]},{"cell_type":"markdown","metadata":{"id":"A2TvmvUFCBJo"},"source":["# **1. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"9ATvHFVpCIZ2"},"source":["\n","## **1.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"Bo0sy_5YCQdY"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xh2v4Wuav4fe"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"zL52lhh5v7tl"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yzATrlv6n14H"},"source":["# **2. 
Install Interactive segmentation**\n","---\n"]},{"cell_type":"code","metadata":{"id":"Xaacxk6suuDP","cellView":"form"},"source":["Notebook_version = ['1.12.7']\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory\n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install Interactive segmentation\n","import time\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# !pip install -U Werkzeug==1.0.1\n","!pip install git+/~https://github.com/imjoy-team/imjoy-interactive-segmentation@master#egg=imjoy-interactive-trainer\n","!python3 -m ipykernel install --user --name imjoy-interactive-ml --display-name \"ImJoy Interactive ML\"\n","\n","from imjoy_interactive_trainer.imjoy_plugin import start_interactive_segmentation\n","from imjoy_interactive_trainer.interactive_trainer import InteractiveTrainer\n","from imjoy_interactive_trainer.data_utils import download_example_dataset\n","from imjoy_interactive_trainer.imgseg.geojson_utils import geojson_to_masks\n","\n","import os\n","import glob\n","from shutil import copyfile, rmtree\n","from tifffile import imread, imsave\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","from matplotlib import pyplot as plt\n","import cv2\n","from tqdm import tqdm\n","from skimage.util.shape import view_as_windows\n","from skimage import io\n","\n","!pip install cellpose \n","from cellpose import models\n","import numpy as np\n","\n","\n","import random\n","from zipfile import ZIP_DEFLATED\n","import csv\n","import pandas as pd\n","\n","\n","def PrepareDataAsPatches(Training_source, patch_width, patch_height, Data_tag):\n","\n"," # Here we assume that the Train and Test folders are already created\n"," patch_num = 0\n","\n"," for file in tqdm(os.listdir(Training_source)):\n"," \n"," if os.path.isfile(os.path.join(Training_source, file)):\n"," img = io.imread(os.path.join(Training_source, file))\n"," _,this_ext = os.path.splitext(file)\n","\n"," if len(img.shape) == 2:\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, 
patch_height))\n"," patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width, patch_height)\n","\n"," elif len(img.shape) == 3:\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height, img.shape[2]), (patch_width, patch_height, img.shape[2]))\n"," patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width, patch_height, img.shape[2])\n","\n"," else:\n"," patches_img = []\n"," print('Data format currently unsupported.')\n","\n"," for i in range(patches_img.shape[0]):\n"," save_path = os.path.join(os.path.splitext(Training_source)[0], 'test','Patch_'+str(patch_num)+' - ' + os.path.splitext(os.path.basename(file))[0])\n"," os.mkdir(save_path)\n"," img_save_path = os.path.join(save_path, Data_tag)\n","\n"," if (len(patches_img[i].shape) == 2):\n"," this_image = np.repeat(patches_img[i][:,:,np.newaxis], repeats=3, axis=2) \n"," else:\n"," this_image = patches_img[i] \n","\n"," # Convert to 8-bit to save as png\n"," this_image =(this_image/this_image.max()*255).astype('uint8')\n"," io.imsave(img_save_path, this_image, check_contrast = False)\n","\n"," # Save raw images patches, preserving format and bit depth\n"," img_save_path_raw = os.path.join(save_path, 'Raw_data'+this_ext)\n"," io.imsave(img_save_path_raw, patches_img[i], check_contrast = False)\n","\n"," patch_num += 1\n","\n","\n","def get_image_list(Folder_path, extension_list = ['*.jpg', '*.tif', '*.png']):\n"," image_list = []\n"," for ext in extension_list:\n"," image_list = image_list + glob.glob(Folder_path+\"/\"+ext)\n","\n"," n_files = len(image_list)\n"," print('Number of files: '+str(n_files))\n","\n"," filenames_list = []\n"," for img_name in image_list:\n"," filenames_list.append(os.path.basename(img_name))\n","\n"," return image_list, filenames_list\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","print('Notebook version: '+Notebook_version[0])\n","\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\"+ bcolors.NORMAL)\n","\n","\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for installation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","print(\"-----------\")\n","print(\"Interactive segmentation installed.\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7dCOBJqkCuSf"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"PE6tkvQQC06f"},"source":["## **3.1. 
Set and prepare dataset**\n","---\n","\n","**WARNING: Currently this notebook only builds 'Grayscale' Cellpose models. So please provide only grayscale equivalent dataset. WARNING.**\n","\n","**`Data_folder:`:** This is the path to the data to use for interactive annotation. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below. \n","\n","**`Use_patch:`:** This option splits the data available into patches of **`patch_size`** x **`patch_size`**. This allows to make all data consistent and formatted. We recommend to always use this option for stability. \n","\n","**`Reset_data:`:** Resetting the data will empty the training data folder and remove all the annotations available from previous uses.\n","\n","**`Use_example_data:`:** This will download and use the example data provided by [Ouyang *et al.* in 2021, on F1000R](https://f1000research.com/articles/10-142?s=09#ref-15).\n"]},{"cell_type":"code","metadata":{"id":"UgRFCIVFVa86","cellView":"form"},"source":["#@markdown ###**Prepare data**\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","#@markdown ###Split the data in small non-verlapping patches?\n","Use_patches = True #@param {type:\"boolean\"}\n","\n","patch_size = 256#@param {type:\"integer\"}\n","\n","#@markdown ###Reset the data? (**!annotations will be lost!**)\n","Reset_data = False #@param {type:\"boolean\"}\n","#@markdown ###Otherwise, use example data\n","Use_example_dataset = False #@param {type:\"boolean\"}\n","\n","\n","if Use_example_dataset:\n"," Data_folder = \"/content/data/hpa_dataset_v2\"\n"," download_example_dataset()\n"," Data_split = [\"microtubules.png\", \"er.png\", \"nuclei.png\"]\n"," channels=[2, 3]\n","\n","\n","else:\n"," Data_tag = \"data.png\" #Kaibu works best with PNGs!\n"," Data_split = [Data_tag]\n"," channels=[0, 0] # grayscale images without nuclei channel\n","\n"," if (Reset_data) and (os.path.exists(os.path.join(Data_folder, \"train\"))):\n"," rmtree(os.path.join(Data_folder, \"train\"))\n","\n"," if (Reset_data) and (os.path.exists(os.path.join(Data_folder, \"test\"))):\n"," rmtree(os.path.join(Data_folder, \"test\"))\n","\n"," if (os.path.exists(os.path.join(Data_folder, \"train\"))) and (os.path.exists(os.path.join(Data_folder, \"test\"))):\n"," print(\"Kaibu data already exist. 
Starting from these annotations!\")\n"," else:\n"," print(\"Creating new folders!\")\n","\n"," os.mkdir(os.path.join(Data_folder, \"train\"))\n"," os.mkdir(os.path.join(Data_folder, \"test\"))\n","\n"," if Use_patches:\n"," PrepareDataAsPatches(Data_folder, patch_size,patch_size, Data_tag)\n"," else:\n"," image_list, _ = get_image_list(Data_folder)\n"," # jpeg_image_list = glob.glob(Data_folder+\"/*.jpg\")\n"," n_files = len(image_list)\n"," print(\"Total number of files: \"+str(n_files))\n","\n"," for image in image_list:\n"," save_path = os.path.join(Data_folder, \"test\",os.path.splitext(os.path.basename(image))[0])\n"," os.mkdir(save_path)\n"," copyfile(image, save_path+\"/\"+Data_tag)\n","\n","\n","extension_list = ['*.jpg', '*.tif', '*.png']\n","image_list, filenames_list = get_image_list(Data_folder, extension_list)\n","\n","\n","if len(filenames_list) > 0:\n"," # ------------- For display ------------\n"," print('--------------------------------------------------------------')\n"," @interact\n"," def show_example_data(name = filenames_list):\n","\n"," plt.figure(figsize=(13,10))\n"," img = io.imread(os.path.join(Data_folder, name))\n","\n"," plt.imshow(img, cmap='gray')\n"," plt.title('Source image ('+str(img.shape[0])+'x'+str(img.shape[1])+')')\n"," plt.axis('off')\n","\n","\n","\n","\n","print(\"------\")\n","print(\"Data prepared for interactive segmentation.\")\n","\n","\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"6krdgD5IEeMC"},"source":["## **3.2. Prepare the Cellpose model**\n","---\n","\n","**`Model_folder:`:** Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**`Model_name:`:** Name of the model, the notebook will create a folder with this name within which the model will be saved.\n","\n","**`default_diameter:`:** Indicate the diameter of the objects (cells or Nuclei) you want to segment. If you input \"0\", this parameter will be estimated automatically for each of your images.\n","\n","**`model_type:`:** This is the Cellpose model that will be loaded initially and from which it will train from. 
This will allow to run reasonnable predictions even with no additional training data.\n","\n"]},{"cell_type":"code","metadata":{"id":"aSlO_wVmhKwo","cellView":"form"},"source":["#@markdown ###**Prepare model**\n","\n","Model_folder = \"\" #@param {type:\"string\"}\n","Model_name = \"\" #@param {type:\"string\"}\n","\n","if (os.path.exists(os.path.join(Model_folder, Model_name))):\n"," print(bcolors.WARNING +\"Model folder already exists and will be deleted at next step.\"+bcolors.NORMAL)\n","\n","\n","default_diameter = 0 #@param {type:\"number\"}\n","model_type = \"default\" #@param [\"cyto\", \"nuclei\", \"default\", \"none\",\"Own model\"]\n","\n","#@markdown ###**If using your own model, please select the path to the model**\n","own_model_path = \"\" #@param {type:\"string\"}\n","\n","\n","\n","resume = True\n","pretrained_model = None\n","\n","\n","if (model_type == \"default\"):\n"," model_type = 'cyto'\n","\n","if (model_type == \"none\"):\n"," model_type = 'cyto'\n"," resume = False\n","\n","if (model_type == \"Own model\"):\n"," model_type = None\n"," pretrained_model = own_model_path\n","\n","\n","\n","if (default_diameter == 0):\n"," default_diameter = None\n","\n","\n","print(\"------\")\n","# print(model_type)\n","# print(pretrained_model)\n","print(\"Model prepared.\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"qpv5orLZuuDR"},"source":["#**4. Interactive segmentation**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"j2dOSJGFGHCP"},"source":["## **4.1. Run interactive segmentation interface**\n","---\n","\n"," This will start the interactive segmentation interface using ImJoy and Kaibu. \n","\n","* Get an image\n","* Predict on that image (using the pretrained model selected above)\n","* Edit the segmentations using the Selection and Draw tools\n","* You can save the annotations at any time\n","* When you're happy with the annotations, you can send the data for training\n","* You can then start the training\n","* Get a new image and annotate as above, send for training and repeat\n","\n"," The training will be running in the background as you annotate and send more training data. This will both generate high quality training data while building an increasingly good model.\n","\n"," The Kaibu interface can be made fullscreen by clicking on the three dots on the top right of the cell, and select the Full screen option. Then, it needs to be minimized and maximised again.\n"]},{"cell_type":"code","metadata":{"id":"02YTQ_ErW6NY","cellView":"form"},"source":["# @markdown #Start interactive segmentation interface\n","\n","# Restart the trainer if necessary\n","instance_exist = True\n","try:\n"," trainer = InteractiveTrainer.get_instance()\n","except:\n"," instance_exist = False\n","\n","if instance_exist:\n"," print(\"Trainer already exists. 
Restarting trainer!\")\n"," trainer.stop()\n","\n","\n","if (os.path.exists(os.path.join(Model_folder, Model_name))):\n"," print('Deleting pre-existing model folder...')\n"," rmtree(os.path.join(Model_folder, Model_name))\n","\n","os.mkdir(os.path.join(Model_folder, Model_name))\n","\n","model_config = dict(type=\"cellpose\",\n"," model_dir=os.path.join(Model_folder, Model_name),\n"," use_gpu=True,\n"," channels=[0, 0],\n"," style_on=0,\n"," batch_size=1,\n"," default_diameter = default_diameter,\n"," pretrained_model = pretrained_model,\n"," model_type = model_type,\n"," resume = resume)\n","\n","start_interactive_segmentation(model_config,\n"," Data_folder,\n"," Data_split,\n"," object_name=\"cell\",\n"," scale_factor=1.0,\n"," restore_test_annotation=True)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"9kucqOfCGOpa"},"source":["## **4.2. Create masks from annotations**\n","---\n","\n"," This cell will allow you to create and visualise instance segmentation masks. It will be created from the annotations made from the interface and will be saved into a folder called **`Paired training dataset`** in your data folder (**`Data_folder`**). This data can be used to train another segmentation model if necessary."]},{"cell_type":"code","metadata":{"id":"fMSEf9cQ0Nb9","cellView":"form"},"source":["#@markdown ##Create and check annotated training images\n","if (os.path.exists(os.path.join(Data_folder,'Paired training dataset'))):\n"," rmtree(os.path.join(Data_folder,'Paired training dataset'))\n","\n","os.mkdir(os.path.join(Data_folder,'Paired training dataset'))\n","os.mkdir(os.path.join(Data_folder,'Paired training dataset','Images'))\n","os.mkdir(os.path.join(Data_folder,'Paired training dataset','Masks'))\n","\n","\n","dir_list = os.listdir(os.path.join(Data_folder, \"train\"))\n","# _, ext = os.path.splitext(Data_tag)\n","\n","for dir in dir_list:\n"," annotation_file = os.path.join(Data_folder, \"train\", dir, \"annotation.json\") \n"," mask_dict = geojson_to_masks(annotation_file, mask_types=[\"labels\"]) \n"," labels = mask_dict[\"labels\"]\n"," imsave(os.path.join(Data_folder, \"train\", dir, \"label.tif\"), labels)\n","\n"," imsave(os.path.join(Data_folder,'Paired training dataset','Masks', dir+\".tif\"), labels)\n","\n","\n"," file_list = os.listdir(os.path.join(Data_folder, \"train\", dir))\n"," for file in file_list:\n"," filename, this_ext = os.path.splitext(file)\n"," if filename == 'Raw_data':\n"," copyfile(os.path.join(Data_folder, \"train\", dir, file), os.path.join(Data_folder,'Paired training dataset','Images', dir+this_ext))\n","\n"," # raw_data_tag = glob.glob(dir+\"/Raw_data.*\")\n"," # print(raw_data_tag)\n"," # copyfile(os.path.join(Data_folder, \"train\", dir, Data_tag), os.path.join(Data_folder,'Paired training dataset','Images', dir+ext))\n","\n","\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_labels(dir=dir_list):\n"," plt.figure(figsize=(13,10))\n","\n"," imgSource = cv2.imread(os.path.join(Data_folder, \"train\", dir, Data_tag))\n"," imgLabel = imread(os.path.join(Data_folder, \"train\", dir, \"label.tif\"))\n","\n"," plt.subplot(121)\n"," plt.imshow(imgSource, cmap='gray', interpolation='nearest')\n"," plt.title('Source image')\n"," plt.axis('off')\n"," plt.subplot(122)\n"," plt.imshow(imgLabel, cmap='nipy_spectral', interpolation='nearest')\n"," plt.title('Label')\n"," 
plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yrfFxMalGyE9"},"source":["## **4.3. Stop training**\n","---\n","\n"," Once training has started, the training will carry on until stopped. Here, the training can be stopped. This will automatically create the final model, which can be used for Quality Control (Section 5 below) and for predictions (Section 6 below).\n"]},{"cell_type":"code","metadata":{"id":"yC-u5eKQHUHC","cellView":"form"},"source":["#@markdown ##Stop training\n","\n","# Stop the trainer if it exists\n","instance_exist = True\n","try:\n"," trainer = InteractiveTrainer.get_instance()\n","except:\n"," instance_exist = False\n","\n","print('-------------')\n","if instance_exist:\n"," print(\"Trainer stopped.\")\n"," trainer.stop()\n","else:\n"," print(\"No trainers currently running.\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"krBTcbMlI91b"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"RcdqcaPnc5C9","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, indicate which model you want to assess:\n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","\n","if Use_the_current_trained_model :\n","\n"," QC_model_path = Model_folder+\"/\"+Model_name+\"/final\"\n","\n"," #model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=False, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n"," model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n","\n"," QC_model_folder = os.path.dirname(QC_model_path)\n"," QC_model_name = os.path.basename(QC_model_folder)\n"," Saving_path = QC_model_folder\n","\n"," print(\"The \"+str(QC_model_name)+\" model will be evaluated\")\n","\n","else:\n","\n"," if os.path.exists(QC_model_path):\n"," model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n"," \n"," QC_model_folder = os.path.dirname(QC_model_path)\n"," Saving_path = QC_model_folder\n"," QC_model_name = os.path.basename(QC_model_folder)\n"," print(\"The \"+str(QC_model_name)+\" model will be evaluated\")\n"," \n"," else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","\n","\n","# Here we load the def that perform the QC, code taken from StarDist /~https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py\n","\n","import numpy as np\n","from numba import jit\n","from tqdm import tqdm\n","from scipy.optimize import linear_sum_assignment\n","from collections import namedtuple\n","\n","\n","matching_criteria = dict()\n","\n","def label_are_sequential(y):\n"," \"\"\" returns true if y has only sequential labels from 1... 
\"\"\"\n"," labels = np.unique(y)\n"," return (set(labels)-{0}) == set(range(1,1+labels.max()))\n","\n","\n","def is_array_of_integers(y):\n"," return isinstance(y,np.ndarray) and np.issubdtype(y.dtype, np.integer)\n","\n","\n","def _check_label_array(y, name=None, check_sequential=False):\n"," err = ValueError(\"{label} must be an array of {integers}.\".format(\n"," label = 'labels' if name is None else name,\n"," integers = ('sequential ' if check_sequential else '') + 'non-negative integers',\n"," ))\n"," is_array_of_integers(y) or print(\"An error occured\")\n"," if check_sequential:\n"," label_are_sequential(y) or print(\"An error occured\")\n"," else:\n"," y.min() >= 0 or print(\"An error occured\")\n"," return True\n","\n","\n","def label_overlap(x, y, check=True):\n"," if check:\n"," _check_label_array(x,'x',True)\n"," _check_label_array(y,'y',True)\n"," x.shape == y.shape or _raise(ValueError(\"x and y must have the same shape\"))\n"," return _label_overlap(x, y)\n","\n","@jit(nopython=True)\n","def _label_overlap(x, y):\n"," x = x.ravel()\n"," y = y.ravel()\n"," overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)\n"," for i in range(len(x)):\n"," overlap[x[i],y[i]] += 1\n"," return overlap\n","\n","\n","def intersection_over_union(overlap):\n"," _check_label_array(overlap,'overlap')\n"," if np.sum(overlap) == 0:\n"," return overlap\n"," n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)\n"," n_pixels_true = np.sum(overlap, axis=1, keepdims=True)\n"," return overlap / (n_pixels_pred + n_pixels_true - overlap)\n","\n","matching_criteria['iou'] = intersection_over_union\n","\n","\n","def intersection_over_true(overlap):\n"," _check_label_array(overlap,'overlap')\n"," if np.sum(overlap) == 0:\n"," return overlap\n"," n_pixels_true = np.sum(overlap, axis=1, keepdims=True)\n"," return overlap / n_pixels_true\n","\n","matching_criteria['iot'] = intersection_over_true\n","\n","\n","def intersection_over_pred(overlap):\n"," _check_label_array(overlap,'overlap')\n"," if np.sum(overlap) == 0:\n"," return overlap\n"," n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)\n"," return overlap / n_pixels_pred\n","\n","matching_criteria['iop'] = intersection_over_pred\n","\n","\n","def precision(tp,fp,fn):\n"," return tp/(tp+fp) if tp > 0 else 0\n","def recall(tp,fp,fn):\n"," return tp/(tp+fn) if tp > 0 else 0\n","def accuracy(tp,fp,fn):\n"," # also known as \"average precision\" (?)\n"," # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation\n"," return tp/(tp+fp+fn) if tp > 0 else 0\n","def f1(tp,fp,fn):\n"," # also known as \"dice coefficient\"\n"," return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0\n","\n","\n","def _safe_divide(x,y):\n"," return x/y if y>0 else 0.0\n","\n","def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):\n"," \"\"\"Calculate detection/instance segmentation metrics between ground truth and predicted label images.\n"," Currently, the following metrics are implemented:\n"," 'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'\n"," Corresponding objects of y_true and y_pred are counted as true positives (tp), false positives (fp), and false negatives (fn)\n"," whether their intersection over union (IoU) >= thresh (for criterion='iou', which can be changed)\n"," * mean_matched_score is the mean IoUs of matched true positives\n"," * mean_true_score is the mean IoUs of matched true positives but normalized by the total number of 
GT objects\n"," * panoptic_quality defined as in Eq. 1 of Kirillov et al. \"Panoptic Segmentation\", CVPR 2019\n"," Parameters\n"," ----------\n"," y_true: ndarray\n"," ground truth label image (integer valued)\n"," predicted label image (integer valued)\n"," thresh: float\n"," threshold for matching criterion (default 0.5)\n"," criterion: string\n"," matching criterion (default IoU)\n"," report_matches: bool\n"," if True, additionally calculate matched_pairs and matched_scores (note, that this returns even gt-pred pairs whose scores are below 'thresh')\n"," Returns\n"," -------\n"," Matching object with different metrics as attributes\n"," Examples\n"," --------\n"," >>> y_true = np.zeros((100,100), np.uint16)\n"," >>> y_true[10:20,10:20] = 1\n"," >>> y_pred = np.roll(y_true,5,axis = 0)\n"," >>> stats = matching(y_true, y_pred)\n"," >>> print(stats)\n"," Matching(criterion='iou', thresh=0.5, fp=1, tp=0, fn=1, precision=0, recall=0, accuracy=0, f1=0, n_true=1, n_pred=1, mean_true_score=0.0, mean_matched_score=0.0, panoptic_quality=0.0)\n"," \"\"\"\n"," _check_label_array(y_true,'y_true')\n"," _check_label_array(y_pred,'y_pred')\n"," y_true.shape == y_pred.shape or _raise(ValueError(\"y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes\".format(y_true=y_true, y_pred=y_pred)))\n"," criterion in matching_criteria or _raise(ValueError(\"Matching criterion '%s' not supported.\" % criterion))\n"," if thresh is None: thresh = 0\n"," thresh = float(thresh) if np.isscalar(thresh) else map(float,thresh)\n","\n"," y_true, _, map_rev_true = relabel_sequential(y_true)\n"," y_pred, _, map_rev_pred = relabel_sequential(y_pred)\n","\n"," overlap = label_overlap(y_true, y_pred, check=False)\n"," scores = matching_criteria[criterion](overlap)\n"," assert 0 <= np.min(scores) <= np.max(scores) <= 1\n","\n"," # ignoring background\n"," scores = scores[1:,1:]\n"," n_true, n_pred = scores.shape\n"," n_matched = min(n_true, n_pred)\n","\n"," def _single(thr):\n"," not_trivial = n_matched > 0 and np.any(scores >= thr)\n"," if not_trivial:\n"," # compute optimal matching with scores as tie-breaker\n"," costs = -(scores >= thr).astype(float) - scores / (2*n_matched)\n"," true_ind, pred_ind = linear_sum_assignment(costs)\n"," assert n_matched == len(true_ind) == len(pred_ind)\n"," match_ok = scores[true_ind,pred_ind] >= thr\n"," tp = np.count_nonzero(match_ok)\n"," else:\n"," tp = 0\n"," fp = n_pred - tp\n"," fn = n_true - tp\n"," # assert tp+fp == n_pred\n"," # assert tp+fn == n_true\n","\n"," # the score sum over all matched objects (tp)\n"," sum_matched_score = np.sum(scores[true_ind,pred_ind][match_ok]) if not_trivial else 0.0\n","\n"," # the score average over all matched objects (tp)\n"," mean_matched_score = _safe_divide(sum_matched_score, tp)\n"," # the score average over all gt/true objects\n"," mean_true_score = _safe_divide(sum_matched_score, n_true)\n"," panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)\n","\n"," stats_dict = dict (\n"," criterion = criterion,\n"," thresh = thr,\n"," fp = fp,\n"," tp = tp,\n"," fn = fn,\n"," precision = precision(tp,fp,fn),\n"," recall = recall(tp,fp,fn),\n"," accuracy = accuracy(tp,fp,fn),\n"," f1 = f1(tp,fp,fn),\n"," n_true = n_true,\n"," n_pred = n_pred,\n"," mean_true_score = mean_true_score,\n"," mean_matched_score = mean_matched_score,\n"," panoptic_quality = panoptic_quality,\n"," )\n"," if bool(report_matches):\n"," if not_trivial:\n"," stats_dict.update (\n"," # int() to be json serializable\n"," matched_pairs = 
tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)),\n"," matched_scores = tuple(scores[true_ind,pred_ind]),\n"," matched_tps = tuple(map(int,np.flatnonzero(match_ok))),\n"," )\n"," else:\n"," stats_dict.update (\n"," matched_pairs = (),\n"," matched_scores = (),\n"," matched_tps = (),\n"," )\n"," return namedtuple('Matching',stats_dict.keys())(*stats_dict.values())\n","\n"," return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh))\n","\n","\n","\n","def matching_dataset(y_true, y_pred, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):\n"," \"\"\"matching metrics for list of images, see `stardist.matching.matching`\n"," \"\"\"\n"," len(y_true) == len(y_pred) or _raise(ValueError(\"y_true and y_pred must have the same length.\"))\n"," return matching_dataset_lazy (\n"," tuple(zip(y_true,y_pred)), thresh=thresh, criterion=criterion, by_image=by_image, show_progress=show_progress, parallel=parallel,\n"," )\n","\n","\n","\n","def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):\n","\n"," expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'))\n","\n"," single_thresh = False\n"," if np.isscalar(thresh):\n"," single_thresh = True\n"," thresh = (thresh,)\n","\n"," tqdm_kwargs = {}\n"," tqdm_kwargs['disable'] = not bool(show_progress)\n"," if int(show_progress) > 1:\n"," tqdm_kwargs['total'] = int(show_progress)\n","\n"," # compute matching stats for every pair of label images\n"," if parallel:\n"," from concurrent.futures import ThreadPoolExecutor\n"," fn = lambda pair: matching(*pair, thresh=thresh, criterion=criterion, report_matches=False)\n"," with ThreadPoolExecutor() as pool:\n"," stats_all = tuple(pool.map(fn, tqdm(y_gen,**tqdm_kwargs)))\n"," else:\n"," stats_all = tuple (\n"," matching(y_t, y_p, thresh=thresh, criterion=criterion, report_matches=False)\n"," for y_t,y_p in tqdm(y_gen,**tqdm_kwargs)\n"," )\n","\n"," # accumulate results over all images for each threshold separately\n"," n_images, n_threshs = len(stats_all), len(thresh)\n"," accumulate = [{} for _ in range(n_threshs)]\n"," for stats in stats_all:\n"," for i,s in enumerate(stats):\n"," acc = accumulate[i]\n"," for k,v in s._asdict().items():\n"," if k == 'mean_true_score' and not bool(by_image):\n"," # convert mean_true_score to \"sum_matched_score\"\n"," acc[k] = acc.setdefault(k,0) + v * s.n_true\n"," else:\n"," try:\n"," acc[k] = acc.setdefault(k,0) + v\n"," except TypeError:\n"," pass\n","\n"," # normalize/compute 'precision', 'recall', 'accuracy', 'f1'\n"," for thr,acc in zip(thresh,accumulate):\n"," set(acc.keys()) == expected_keys or _raise(ValueError(\"unexpected keys\"))\n"," acc['criterion'] = criterion\n"," acc['thresh'] = thr\n"," acc['by_image'] = bool(by_image)\n"," if bool(by_image):\n"," for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):\n"," acc[k] /= n_images\n"," else:\n"," tp, fp, fn, n_true = acc['tp'], acc['fp'], acc['fn'], acc['n_true']\n"," sum_matched_score = acc['mean_true_score']\n","\n"," mean_matched_score = _safe_divide(sum_matched_score, tp)\n"," mean_true_score = _safe_divide(sum_matched_score, n_true)\n"," panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)\n","\n"," acc.update(\n"," precision = precision(tp,fp,fn),\n"," recall = 
recall(tp,fp,fn),\n"," accuracy = accuracy(tp,fp,fn),\n"," f1 = f1(tp,fp,fn),\n"," mean_true_score = mean_true_score,\n"," mean_matched_score = mean_matched_score,\n"," panoptic_quality = panoptic_quality,\n"," )\n","\n"," accumulate = tuple(namedtuple('DatasetMatching',acc.keys())(*acc.values()) for acc in accumulate)\n"," return accumulate[0] if single_thresh else accumulate\n","\n","\n","\n","# copied from scikit-image master for now (remove when part of a release)\n","def relabel_sequential(label_field, offset=1):\n"," \"\"\"Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}.\n"," This function also returns the forward map (mapping the original labels to\n"," the reduced labels) and the inverse map (mapping the reduced labels back\n"," to the original ones).\n"," Parameters\n"," ----------\n"," label_field : numpy array of int, arbitrary shape\n"," An array of labels, which must be non-negative integers.\n"," offset : int, optional\n"," The return labels will start at `offset`, which should be\n"," strictly positive.\n"," Returns\n"," -------\n"," relabeled : numpy array of int, same shape as `label_field`\n"," The input label field with labels mapped to\n"," {offset, ..., number_of_labels + offset - 1}.\n"," The data type will be the same as `label_field`, except when\n"," offset + number_of_labels causes overflow of the current data type.\n"," forward_map : numpy array of int, shape ``(label_field.max() + 1,)``\n"," The map from the original label space to the returned label\n"," space. Can be used to re-apply the same mapping. See examples\n"," for usage. The data type will be the same as `relabeled`.\n"," inverse_map : 1D numpy array of int, of length offset + number of labels\n"," The map from the new label space to the original space. This\n"," can be used to reconstruct the original label field from the\n"," relabeled one. The data type will be the same as `relabeled`.\n"," Notes\n"," -----\n"," The label 0 is assumed to denote the background and is never remapped.\n"," The forward map can be extremely big for some inputs, since its\n"," length is given by the maximum of the label field. 
However, in most\n"," situations, ``label_field.max()`` is much smaller than\n"," ``label_field.size``, and in these cases the forward map is\n"," guaranteed to be smaller than either the input or output images.\n"," Examples\n"," --------\n"," >>> from skimage.segmentation import relabel_sequential\n"," >>> label_field = np.array([1, 1, 5, 5, 8, 99, 42])\n"," >>> relab, fw, inv = relabel_sequential(label_field)\n"," >>> relab\n"," array([1, 1, 2, 2, 3, 5, 4])\n"," >>> fw\n"," array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5])\n"," >>> inv\n"," array([ 0, 1, 5, 8, 42, 99])\n"," >>> (fw[label_field] == relab).all()\n"," True\n"," >>> (inv[relab] == label_field).all()\n"," True\n"," >>> relab, fw, inv = relabel_sequential(label_field, offset=5)\n"," >>> relab\n"," array([5, 5, 6, 6, 7, 9, 8])\n"," \"\"\"\n"," offset = int(offset)\n"," if offset <= 0:\n"," raise ValueError(\"Offset must be strictly positive.\")\n"," if np.min(label_field) < 0:\n"," raise ValueError(\"Cannot relabel array that contains negative values.\")\n"," max_label = int(label_field.max()) # Ensure max_label is an integer\n"," if not np.issubdtype(label_field.dtype, np.integer):\n"," new_type = np.min_scalar_type(max_label)\n"," label_field = label_field.astype(new_type)\n"," labels = np.unique(label_field)\n"," labels0 = labels[labels != 0]\n"," new_max_label = offset - 1 + len(labels0)\n"," new_labels0 = np.arange(offset, new_max_label + 1)\n"," output_type = label_field.dtype\n"," required_type = np.min_scalar_type(new_max_label)\n"," if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize:\n"," output_type = required_type\n"," forward_map = np.zeros(max_label + 1, dtype=output_type)\n"," forward_map[labels0] = new_labels0\n"," inverse_map = np.zeros(new_max_label + 1, dtype=output_type)\n"," inverse_map[offset:] = labels0\n"," relabeled = forward_map[label_field]\n"," return relabeled, forward_map, inverse_map\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Drt2bI08JFwc"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by looking at the training loss over training epochs. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","During training values should decrease before reaching a minimal value which does not decrease further even after more training.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. 
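(As an illustrative aside, not part of the original notebook: the per-epoch losses reported by the interactive trainer can be noisy, so it is often easier to judge whether the curve has truly flattened after smoothing it. The sketch below assumes only that `loss` is a plain Python list of per-epoch loss values, such as the one extracted from `trainer.get_reports()` in the plotting cell that follows; the window size of 5 and the 1% tolerance are arbitrary choices.)

```python
import numpy as np
import matplotlib.pyplot as plt

def moving_average(values, window=5):
    """Smooth a noisy loss curve with a simple moving average."""
    values = np.asarray(values, dtype=float)
    if len(values) < window:
        return values  # not enough points to smooth
    return np.convolve(values, np.ones(window) / window, mode="valid")

# Dummy per-epoch losses; in practice, reuse the `loss` list built from trainer.get_reports()
loss = [0.90, 0.70, 0.62, 0.55, 0.52, 0.50, 0.49, 0.485, 0.484, 0.483]

smoothed = moving_average(loss, window=5)
offset = len(loss) - len(smoothed)

plt.semilogy(loss, alpha=0.4, label="raw loss")
plt.semilogy(np.arange(len(smoothed)) + offset, smoothed, label="smoothed loss")
plt.xlabel("Epoch number")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Crude convergence check: relative change over the last few smoothed values
if len(smoothed) >= 5 and abs(smoothed[-1] - smoothed[-5]) / smoothed[-5] < 0.01:
    print("Smoothed loss changed by <1% recently: the curve looks flat.")
```

A curve that stays flat after smoothing is a reasonable sign that the model has converged.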
After this point no further training is required.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"0W3iPXWBuuDS","scrolled":false,"cellView":"form"},"source":["#@markdown ##Plot loss function\n","\n","# Stop the trainer if it exists\n","instance_exist = True\n","try:\n"," trainer = InteractiveTrainer.get_instance()\n","except:\n"," instance_exist = False\n","\n","if Use_the_current_trained_model and instance_exist:\n","\n"," trainer = InteractiveTrainer.get_instance()\n"," reports = trainer.get_reports()\n"," import matplotlib.pyplot as plt\n"," loss = [report['loss'] for report in reports]\n","\n"," plt.figure(figsize=(15,10))\n","\n"," plt.subplot(2,1,1)\n"," plt.plot(loss, label='Training loss')\n"," plt.title('Training loss vs. epoch number (linear scale)')\n"," plt.ylabel('Loss')\n"," plt.xlabel('Epoch number')\n","\n","\n"," plt.subplot(2,1,2)\n"," plt.semilogy(loss, label='Training loss')\n"," plt.title('Training loss vs. epoch number (log scale)')\n"," plt.ylabel('Loss')\n"," plt.xlabel('Epoch number')\n"," # plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'), bbox_inches='tight', pad_inches=0)\n"," plt.show()\n","\n","else:\n"," print(bcolors.WARNING+\"Loss curves can currently only be obtained from a currently trained model.\"+bcolors.NORMAL)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YZMyEUreRsD_"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the `Source_QC_folder` and `Target_QC_folder` ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** (IoU) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","Here, the IoU is both calculated over the whole image and on a per-object basis. The value displayed below is the IoU value calculated over the entire image. The IoU value calculated on a per-object basis is used to calculate the other metrics displayed.\n","\n","“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. \n","\n","When a segmented object has an IoU value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. The number of “**true positives**” is available in the table below. The number of “false positive” is then defined as “**false positive**” = “n_pred” - “true positive”. The number of “false negative” is defined as “false negative” = “n_true” - “true positive”.\n","\n","The mean_matched_score is the mean IoUs of matched true positives. The mean_true_score is the mean IoUs of matched true positives but normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 
2019](https://arxiv.org/abs/1801.00868).\n","\n","For more information about the other metric displayed, please consult the SI of the paper describing ZeroCostDL4Mic.\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your `model_folder`."]},{"cell_type":"code","metadata":{"id":"b8gyZwoERt58","cellView":"form"},"source":["\n","\n","\n","\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","\n","#@markdown ### Segmentation parameters:\n","Object_diameter = 0#@param {type:\"number\"}\n","\n","Flow_threshold = 0.4 #@param {type:\"slider\", min:0.1, max:1.1, step:0.1}\n","Cell_probability_threshold=0 #@param {type:\"slider\", min:-6, max:6, step:1}\n","\n","if Object_diameter is 0:\n"," Object_diameter = None\n"," print(\"The cell size will be estimated automatically for each image\")\n","\n","\n","# Find the number of channel in the input image\n","\n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = io.imread(Source_QC_folder+\"/\"+random_choice)\n","n_channel = 1 if x.ndim == 2 else x.shape[-1]\n","\n","\n","channels=[0,0]\n","QC_model_folder = os.path.join(Model_folder,Model_name)\n","QC_model_path = os.path.join(QC_model_folder, 'final')\n","QC_model_name = Model_name\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_folder+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_folder+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_folder+\"/Quality Control/Prediction\"):\n"," rmtree(QC_model_folder+\"/Quality Control/Prediction\")\n","os.makedirs(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","\n","model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n","\n","\n","# Here we need to make predictions\n","\n","for name in os.listdir(Source_QC_folder):\n"," \n"," print(\"Performing prediction on: \"+name)\n"," image = io.imread(Source_QC_folder+\"/\"+name) \n","\n"," short_name = os.path.splitext(name)\n"," masks, flows, styles = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," \n"," os.chdir(QC_model_folder+\"/Quality Control/Prediction\")\n"," imsave(str(short_name[0])+\".tif\", masks, compress=ZIP_DEFLATED) \n"," \n","# Here we start testing the differences between GT and predicted masks\n","\n","with open(QC_model_folder+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file, delimiter=\",\")\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\", \"false positive\", \"true positive\", \"false negative\", \"precision\", \"recall\", \"accuracy\", \"f1 score\", \"n_true\", \"n_pred\", \"mean_true_score\", \"mean_matched_score\", \"panoptic_quality\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_folder+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," # Calculate the matching (with IoU threshold `thresh`) and all metrics\n","\n"," stats = matching(test_ground_truth_image, test_prediction, thresh=0.5)\n"," \n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])\n","\n","from tabulate import tabulate\n","\n","df = pd.read_csv (QC_model_folder+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\")\n","print(tabulate(df, headers='keys', tablefmt='psql'))\n","\n","\n","from astropy.visualization import simple_norm\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Source_QC_folder)):\n"," \n","\n"," plt.figure(figsize=(25,5))\n"," if n_channel > 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file))\n"," if n_channel == 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n","\n"," target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_folder+\"/Quality Control/Prediction/\"+file, as_gray = True)\n","\n"," stats = matching(prediction, target_image, thresh=0.5)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," if n_channel > 1:\n"," plt.imshow(source_image)\n"," if n_channel == 1:\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," 
plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));\n"," plt.savefig(QC_model_folder+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","# full_QC_model_path = QC_model_folder+'/'\n","# qc_pdf_export()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JqAGDvAI7ALU"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"Lf9YM22iIlzv"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the model's name and path to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder_prediction`:** This folder should contain the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder is where the results from the predictions will be saved.\n","\n","**`Flow_threshold`:** This parameter controls the maximum allowed error of the flows for each mask. Increase this threshold if cellpose is not returning as many masks as you'd expect. Similarly, decrease this threshold if cellpose is returning too many ill-shaped masks. **Default value: 0.4**\n","\n","**`Cell_probability_threshold`:** The pixels greater than the Cell_probability_threshold are used to run dynamics and determine masks. Decrease this threshold if cellpose is not returning as many masks as you'd expect. Similarly, increase this threshold if cellpose is returning too many masks, particularly from dim areas. 
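(As an illustrative aside, not part of the original notebook: because the best `Cell_probability_threshold` depends on your data, it can be worth sweeping a few values on a single image and comparing the number of objects returned before running the full prediction cell. The sketch below assumes `model` is a Cellpose model already loaded as in the Quality Control section, `img` is one image loaded with `io.imread`, and that counting the distinct labels in the returned mask is an adequate proxy for the number of objects.)

```python
import numpy as np

# Assumes `model` (a loaded cellpose CellposeModel) and `img` (a 2D grayscale image) exist,
# as in the Quality Control cell of this notebook.
for cellprob in [-2, 0, 2]:
    masks, flows, styles = model.eval(img,
                                      diameter=None,           # use the model's default object size
                                      flow_threshold=0.4,
                                      cellprob_threshold=cellprob,
                                      channels=[0, 0])         # grayscale, no nuclei channel
    n_objects = len(np.unique(masks)) - 1                      # label 0 is the background
    print(f"cellprob_threshold={cellprob}: {n_objects} objects")
```

If in doubt, keep the default given below.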
**Default value: 0.0**"]},{"cell_type":"code","metadata":{"id":"Yj8M-RD37RQA","cellView":"form"},"source":["# -------------------------------------------------- \n","#@markdown ###Data parameters\n","Data_folder_prediction = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Model parameters\n","#@markdown Do you want to use the model you just trained?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","#@markdown Otherwise, please provide path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Segmentation parameters:\n","Object_diameter = 0#@param {type:\"number\"}\n","Flow_threshold = 0.4 #@param {type:\"slider\", min:0.1, max:1.1, step:0.1}\n","Cell_probability_threshold=0 #@param {type:\"slider\", min:-6, max:6, step:1}\n","\n","# -------------------------------------------------- \n","# TODO: allow to run on other file formats\n","# prediction_image_list = glob.glob(Data_folder_prediction+\"/*.jpg\")\n","# n_files = len(prediction_image_list)\n","\n","# filenames_list = []\n","# for name in os.listdir(Data_folder_prediction):\n","# if os.path.isfile(name):\n","# filenames_list.append(os.path.splitext(name)[0])\n","\n","# filenames_list = []\n","# for name in prediction_image_list:\n","# filenames_list.append(os.path.splitext(os.path.basename(name))[0])\n","\n","extension_list = ['*.jpg', '*.tif', '*.png']\n","prediction_image_list, filenames_list = get_image_list(Data_folder_prediction, extension_list)\n","\n","n_files = len(prediction_image_list)\n","print(\"Total number of files: \"+str(n_files))\n","\n","if Use_the_current_trained_model:\n"," prediction_model_path = os.path.join(Model_folder, Model_name)\n","\n","\n","model_path = os.path.join(prediction_model_path, \"final\")\n","# TODO: Check the line below for file compatibility\n","channels=[0,0] \n","model = models.CellposeModel(gpu=True, pretrained_model=model_path, torch=False, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n","\n","\n","if (Object_diameter == 0):\n"," Object_diameter = None\n","\n","masks_list = []\n","for i in tqdm(range(n_files)):\n"," img = io.imread(prediction_image_list[i])\n"," masks, flows, styles = model.eval(img, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," imsave(os.path.join(Result_folder, filenames_list[i]+'.tif'), masks)\n","\n","\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_labels(filename = filenames_list):\n"," plt.figure(figsize=(13,10))\n","\n"," img = io.imread(os.path.join(Data_folder_prediction, filename))\n"," mask = io.imread(os.path.join(Result_folder, filename+'.tif'))\n","\n"," plt.subplot(121)\n"," plt.imshow(img, cmap='gray', interpolation='nearest')\n"," plt.title('Source image')\n"," plt.axis('off')\n"," plt.subplot(122)\n"," plt.imshow(mask, cmap='nipy_spectral', interpolation='nearest')\n"," plt.title('Label')\n"," plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"D0I9oX82QDk6"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"2nfGReX7QJnE"},"source":["---\n","#**Thank you for using Interactive segmentation - Cellpose 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"ZeroCostDL4Mic_Interactive_annotations_Cellpose.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"ImJoy Interactive ML","language":"python","name":"imjoy-interactive-ml"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.9"}},"cells":[{"cell_type":"markdown","metadata":{"id":"8KzUtudGvysh"},"source":["# **Interactive segmentation**\n","---\n","\n","**Interactive segmentation** is a segmentation tool powered by deep learning and ImJoy that can be used to segment bioimages and was first published by [Ouyang *et al.* in 2021, on F1000R](https://f1000research.com/articles/10-142?s=09#ref-15).\n","\n","**The Original code** is freely available in GitHub:\n","/~https://github.com/imjoy-team/imjoy-interactive-segmentation\n","\n","**Please also cite this original paper when using or developing this notebook.**\n","\n","**!!Currently, this notebook only works with Google Chrome or Firefox!!**\n","\n"]},{"cell_type":"markdown","metadata":{"id":"yzATrlv6n14H"},"source":["# **1. Install Interactive segmentation**\n","---\n"]},{"cell_type":"code","metadata":{"id":"Xaacxk6suuDP","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'Kaibu'\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory\n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list]\n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install Interactive segmentation\n","import time\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# !pip install -U Werkzeug==1.0.1\n","!pip install git+/~https://github.com/imjoy-team/imjoy-interactive-segmentation@master#egg=imjoy-interactive-trainer\n","!python3 -m ipykernel install --user --name 
imjoy-interactive-ml --display-name \"ImJoy Interactive ML\"\n","\n","from imjoy_interactive_trainer.imjoy_plugin import start_interactive_segmentation\n","from imjoy_interactive_trainer.interactive_trainer import InteractiveTrainer\n","from imjoy_interactive_trainer.data_utils import download_example_dataset\n","from imjoy_interactive_trainer.imgseg.geojson_utils import geojson_to_masks\n","\n","import os\n","import glob\n","from shutil import copyfile, rmtree\n","from tifffile import imread, imsave\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","from matplotlib import pyplot as plt\n","import cv2\n","from tqdm import tqdm\n","from skimage.util.shape import view_as_windows\n","from skimage import io\n","\n","!pip install cellpose \n","from cellpose import models\n","import numpy as np\n","\n","\n","import random\n","from zipfile import ZIP_DEFLATED\n","import csv\n","import pandas as pd\n","\n","from numba import jit\n","from scipy.optimize import linear_sum_assignment\n","from collections import namedtuple\n","\n","from tabulate import tabulate\n","from astropy.visualization import simple_norm\n","import matplotlib.pyplot as plt\n","\n","from concurrent.futures import ThreadPoolExecutor\n","\n","def PrepareDataAsPatches(Training_source, patch_width, patch_height, Data_tag):\n","\n"," # Here we assume that the Train and Test folders are already created\n"," patch_num = 0\n","\n"," for file in tqdm(os.listdir(Training_source)):\n"," \n"," if os.path.isfile(os.path.join(Training_source, file)):\n"," img = io.imread(os.path.join(Training_source, file))\n"," _,this_ext = os.path.splitext(file)\n","\n"," if len(img.shape) == 2:\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n"," patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width, patch_height)\n","\n"," elif len(img.shape) == 3:\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height, img.shape[2]), (patch_width, patch_height, img.shape[2]))\n"," patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width, patch_height, img.shape[2])\n","\n"," else:\n"," patches_img = []\n"," print('Data format currently unsupported.')\n","\n"," for i in range(patches_img.shape[0]):\n"," save_path = os.path.join(os.path.splitext(Training_source)[0], 'test','Patch_'+str(patch_num)+' - ' + os.path.splitext(os.path.basename(file))[0])\n"," os.mkdir(save_path)\n"," img_save_path = os.path.join(save_path, Data_tag)\n","\n"," if (len(patches_img[i].shape) == 2):\n"," this_image = np.repeat(patches_img[i][:,:,np.newaxis], repeats=3, axis=2) \n"," else:\n"," this_image = patches_img[i] \n","\n"," # Convert to 8-bit to save as png\n"," this_image =(this_image/this_image.max()*255).astype('uint8')\n"," io.imsave(img_save_path, this_image, check_contrast = False)\n","\n"," # Save raw images patches, preserving format and bit depth\n"," img_save_path_raw = os.path.join(save_path, 'Raw_data'+this_ext)\n"," io.imsave(img_save_path_raw, patches_img[i], check_contrast = False)\n","\n"," patch_num += 1\n","\n","\n","def get_image_list(Folder_path, extension_list = ['*.jpg', '*.tif', '*.png']):\n"," image_list = []\n"," for ext in extension_list:\n"," image_list = image_list + glob.glob(Folder_path+\"/\"+ext)\n","\n"," n_files = 
len(image_list)\n"," print('Number of files: '+str(n_files))\n","\n"," filenames_list = []\n"," for img_name in image_list:\n"," filenames_list.append(os.path.basename(img_name))\n","\n"," return image_list, filenames_list\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n","\n","# Check if this is the latest version of the notebook\n","# Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","# print('Notebook version: '+Notebook_version[0])\n","\n","# strlist = Notebook_version[0].split('.')\n","# Notebook_version_main = strlist[0]+'.'+strlist[1]\n","\n","# if Notebook_version_main == Latest_notebook_version.columns:\n","# print(\"This notebook is up-to-date.\")\n","# else:\n","# print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\"+ bcolors.NORMAL)\n","\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for installation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","print(\"-----------\")\n","print(\"Interactive segmentation installed.\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"A2TvmvUFCBJo"},"source":["# **2. Complete the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"9ATvHFVpCIZ2"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"Bo0sy_5YCQdY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. 
To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xh2v4Wuav4fe"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"zL52lhh5v7tl","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7dCOBJqkCuSf"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"PE6tkvQQC06f"},"source":["## **3.1. Set and prepare dataset**\n","---\n","\n","**WARNING: Currently this notebook only builds 'Grayscale' Cellpose models. So please provide only grayscale equivalent dataset. WARNING.**\n","\n","**`Data_folder:`:** This is the path to the data to use for interactive annotation. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below. \n","\n","**`Use_patch:`:** This option splits the data available into patches of **`patch_size`** x **`patch_size`**. This allows to make all data consistent and formatted. We recommend to always use this option for stability. \n","\n","**`Reset_data:`:** Resetting the data will empty the training data folder and remove all the annotations available from previous uses.\n","\n","**`Use_example_data:`:** This will download and use the example data provided by [Ouyang *et al.* in 2021, on F1000R](https://f1000research.com/articles/10-142?s=09#ref-15).\n"]},{"cell_type":"code","metadata":{"id":"UgRFCIVFVa86","cellView":"form"},"source":["#@markdown ###**Prepare data**\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","#@markdown ###Split the data in small non-verlapping patches?\n","Use_patches = True #@param {type:\"boolean\"}\n","\n","patch_size = 256#@param {type:\"integer\"}\n","\n","#@markdown ###Reset the data? 
(**!annotations will be lost!**)\n","Reset_data = False #@param {type:\"boolean\"}\n","#@markdown ###Otherwise, use example data\n","Use_example_dataset = False #@param {type:\"boolean\"}\n","\n","\n","if Use_example_dataset:\n"," Data_folder = \"/content/data/hpa_dataset_v2\"\n"," download_example_dataset()\n"," Data_split = [\"microtubules.png\", \"er.png\", \"nuclei.png\"]\n"," channels=[2, 3]\n","\n","\n","else:\n"," Data_tag = \"data.png\" #Kaibu works best with PNGs!\n"," Data_split = [Data_tag]\n"," channels=[0, 0] # grayscale images without nuclei channel\n","\n"," if (Reset_data) and (os.path.exists(os.path.join(Data_folder, \"train\"))):\n"," rmtree(os.path.join(Data_folder, \"train\"))\n","\n"," if (Reset_data) and (os.path.exists(os.path.join(Data_folder, \"test\"))):\n"," rmtree(os.path.join(Data_folder, \"test\"))\n","\n"," if (os.path.exists(os.path.join(Data_folder, \"train\"))) and (os.path.exists(os.path.join(Data_folder, \"test\"))):\n"," print(\"Kaibu data already exist. Starting from these annotations!\")\n"," else:\n"," print(\"Creating new folders!\")\n","\n"," os.mkdir(os.path.join(Data_folder, \"train\"))\n"," os.mkdir(os.path.join(Data_folder, \"test\"))\n","\n"," if Use_patches:\n"," PrepareDataAsPatches(Data_folder, patch_size,patch_size, Data_tag)\n"," else:\n"," image_list, _ = get_image_list(Data_folder)\n"," # jpeg_image_list = glob.glob(Data_folder+\"/*.jpg\")\n"," n_files = len(image_list)\n"," print(\"Total number of files: \"+str(n_files))\n","\n"," for image in image_list:\n"," save_path = os.path.join(Data_folder, \"test\",os.path.splitext(os.path.basename(image))[0])\n"," os.mkdir(save_path)\n"," copyfile(image, save_path+\"/\"+Data_tag)\n","\n","\n","extension_list = ['*.jpg', '*.tif', '*.png']\n","image_list, filenames_list = get_image_list(Data_folder, extension_list)\n","\n","\n","if len(filenames_list) > 0:\n"," # ------------- For display ------------\n"," print('--------------------------------------------------------------')\n"," @interact\n"," def show_example_data(name = filenames_list):\n","\n"," plt.figure(figsize=(13,10))\n"," img = io.imread(os.path.join(Data_folder, name))\n","\n"," plt.imshow(img, cmap='gray')\n"," plt.title('Source image ('+str(img.shape[0])+'x'+str(img.shape[1])+')')\n"," plt.axis('off')\n","\n","\n","\n","\n","print(\"------\")\n","print(\"Data prepared for interactive segmentation.\")\n","\n","\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"6krdgD5IEeMC"},"source":["## **3.2. Prepare the Cellpose model**\n","---\n","\n","**`Model_folder:`:** Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**`Model_name:`:** Name of the model, the notebook will create a folder with this name within which the model will be saved.\n","\n","**`default_diameter:`:** Indicate the diameter of the objects (cells or Nuclei) you want to segment. If you input \"0\", this parameter will be estimated automatically for each of your images.\n","\n","**`model_type:`:** This is the Cellpose model that will be loaded initially and from which it will train from. 
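(As an illustrative aside, not part of the original notebook: before choosing a starting point above, it can be useful to run the generic pretrained Cellpose model once on a representative image to check that it already produces sensible masks for your data. The sketch below uses the standard `models.Cellpose` wrapper with the built-in `cyto` weights; the image path is a placeholder that you would need to replace.)

```python
from cellpose import models
from skimage import io
import matplotlib.pyplot as plt

# Placeholder path: replace with one of your own grayscale images on Google Drive.
img = io.imread("/content/gdrive/MyDrive/example_image.png")

# Generic pretrained "cyto" model; diameter=None asks Cellpose to estimate the object size.
model = models.Cellpose(gpu=True, model_type="cyto")
masks, flows, styles, diams = model.eval(img, diameter=None, channels=[0, 0])

print(f"Estimated object diameter: {diams}")
plt.imshow(masks, cmap="nipy_spectral")
plt.axis("off")
plt.show()
```

Whichever starting point you pick, the cell below simply records which pretrained weights the interactive trainer should resume from.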
This will allow to run reasonnable predictions even with no additional training data.\n","\n"]},{"cell_type":"code","metadata":{"id":"aSlO_wVmhKwo","cellView":"form"},"source":["#@markdown ###**Prepare model**\n","\n","Model_folder = \"\" #@param {type:\"string\"}\n","Model_name = \"\" #@param {type:\"string\"}\n","\n","if (os.path.exists(os.path.join(Model_folder, Model_name))):\n"," print(bcolors.WARNING +\"Model folder already exists and will be deleted at next step.\"+bcolors.NORMAL)\n","\n","\n","default_diameter = 0 #@param {type:\"number\"}\n","model_type = \"default\" #@param [\"cyto\", \"nuclei\", \"default\", \"none\",\"Own model\"]\n","\n","#@markdown ###**If using your own model, please select the path to the model**\n","own_model_path = \"\" #@param {type:\"string\"}\n","\n","\n","\n","resume = True\n","pretrained_model = None\n","\n","\n","if (model_type == \"default\"):\n"," model_type = 'cyto'\n","\n","if (model_type == \"none\"):\n"," model_type = 'cyto'\n"," resume = False\n","\n","if (model_type == \"Own model\"):\n"," model_type = None\n"," pretrained_model = own_model_path\n","\n","\n","\n","if (default_diameter == 0):\n"," default_diameter = None\n","\n","\n","print(\"------\")\n","# print(model_type)\n","# print(pretrained_model)\n","print(\"Model prepared.\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"qpv5orLZuuDR"},"source":["#**4. Interactive segmentation**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"j2dOSJGFGHCP"},"source":["## **4.1. Run interactive segmentation interface**\n","---\n","\n"," This will start the interactive segmentation interface using ImJoy and Kaibu. \n","\n","* Get an image\n","* Predict on that image (using the pretrained model selected above)\n","* Edit the segmentations using the Selection and Draw tools\n","* You can save the annotations at any time\n","* When you're happy with the annotations, you can send the data for training\n","* You can then start the training\n","* Get a new image and annotate as above, send for training and repeat\n","\n"," The training will be running in the background as you annotate and send more training data. This will both generate high quality training data while building an increasingly good model.\n","\n"," The Kaibu interface can be made fullscreen by clicking on the three dots on the top right of the cell, and select the Full screen option. Then, it needs to be minimized and maximised again.\n"]},{"cell_type":"code","metadata":{"id":"02YTQ_ErW6NY","cellView":"form"},"source":["# @markdown #Start interactive segmentation interface\n","\n","# Restart the trainer if necessary\n","instance_exist = True\n","try:\n"," trainer = InteractiveTrainer.get_instance()\n","except:\n"," instance_exist = False\n","\n","if instance_exist:\n"," print(\"Trainer already exists. 
Restarting trainer!\")\n"," trainer.stop()\n","\n","\n","if (os.path.exists(os.path.join(Model_folder, Model_name))):\n"," print('Deleting pre-existing model folder...')\n"," rmtree(os.path.join(Model_folder, Model_name))\n","\n","os.mkdir(os.path.join(Model_folder, Model_name))\n","\n","model_config = dict(type=\"cellpose\",\n"," model_dir=os.path.join(Model_folder, Model_name),\n"," use_gpu=True,\n"," channels=[0, 0],\n"," style_on=0,\n"," batch_size=1,\n"," default_diameter = default_diameter,\n"," pretrained_model = pretrained_model,\n"," model_type = model_type,\n"," resume = resume)\n","\n","start_interactive_segmentation(model_config,\n"," Data_folder,\n"," Data_split,\n"," object_name=\"cell\",\n"," scale_factor=1.0,\n"," restore_test_annotation=True)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"9kucqOfCGOpa"},"source":["## **4.2. Create masks from annotations**\n","---\n","\n"," This cell will allow you to create and visualise instance segmentation masks. It will be created from the annotations made from the interface and will be saved into a folder called **`Paired training dataset`** in your data folder (**`Data_folder`**). This data can be used to train another segmentation model if necessary."]},{"cell_type":"code","metadata":{"id":"fMSEf9cQ0Nb9","cellView":"form"},"source":["#@markdown ##Create and check annotated training images\n","if (os.path.exists(os.path.join(Data_folder,'Paired training dataset'))):\n"," rmtree(os.path.join(Data_folder,'Paired training dataset'))\n","\n","os.mkdir(os.path.join(Data_folder,'Paired training dataset'))\n","os.mkdir(os.path.join(Data_folder,'Paired training dataset','Images'))\n","os.mkdir(os.path.join(Data_folder,'Paired training dataset','Masks'))\n","\n","\n","dir_list = os.listdir(os.path.join(Data_folder, \"train\"))\n","# _, ext = os.path.splitext(Data_tag)\n","\n","for dir in dir_list:\n"," annotation_file = os.path.join(Data_folder, \"train\", dir, \"annotation.json\") \n"," mask_dict = geojson_to_masks(annotation_file, mask_types=[\"labels\"]) \n"," labels = mask_dict[\"labels\"]\n"," imsave(os.path.join(Data_folder, \"train\", dir, \"label.tif\"), labels)\n","\n"," imsave(os.path.join(Data_folder,'Paired training dataset','Masks', dir+\".tif\"), labels)\n","\n","\n"," file_list = os.listdir(os.path.join(Data_folder, \"train\", dir))\n"," for file in file_list:\n"," filename, this_ext = os.path.splitext(file)\n"," if filename == 'Raw_data':\n"," copyfile(os.path.join(Data_folder, \"train\", dir, file), os.path.join(Data_folder,'Paired training dataset','Images', dir+this_ext))\n","\n"," # raw_data_tag = glob.glob(dir+\"/Raw_data.*\")\n"," # print(raw_data_tag)\n"," # copyfile(os.path.join(Data_folder, \"train\", dir, Data_tag), os.path.join(Data_folder,'Paired training dataset','Images', dir+ext))\n","\n","\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_labels(dir=dir_list):\n"," plt.figure(figsize=(13,10))\n","\n"," imgSource = cv2.imread(os.path.join(Data_folder, \"train\", dir, Data_tag))\n"," imgLabel = imread(os.path.join(Data_folder, \"train\", dir, \"label.tif\"))\n","\n"," plt.subplot(121)\n"," plt.imshow(imgSource, cmap='gray', interpolation='nearest')\n"," plt.title('Source image')\n"," plt.axis('off')\n"," plt.subplot(122)\n"," plt.imshow(imgLabel, cmap='nipy_spectral', interpolation='nearest')\n"," plt.title('Label')\n"," 
plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yrfFxMalGyE9"},"source":["## **4.3. Stop training**\n","---\n","\n"," Once training has started, the training will carry on until stopped. Here, the training can be stopped. This will automatically create the final model, which can be used for Quality Control (Section 5 below) and for predictions (Section 6 below).\n"]},{"cell_type":"code","metadata":{"id":"yC-u5eKQHUHC","cellView":"form"},"source":["#@markdown ##Stop training\n","\n","# Stop the trainer if it exists\n","instance_exist = True\n","try:\n"," trainer = InteractiveTrainer.get_instance()\n","except:\n"," instance_exist = False\n","\n","print('-------------')\n","if instance_exist:\n"," print(\"Trainer stopped.\")\n"," trainer.stop()\n","else:\n"," print(\"No trainers currently running.\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"krBTcbMlI91b"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"RcdqcaPnc5C9","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, indicate which model you want to assess:\n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","\n","if Use_the_current_trained_model :\n","\n"," QC_model_path = Model_folder+\"/\"+Model_name+\"/final\"\n","\n"," #model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=False, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n"," model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n","\n"," QC_model_folder = os.path.dirname(QC_model_path)\n"," QC_model_name = os.path.basename(QC_model_folder)\n"," Saving_path = QC_model_folder\n","\n"," print(\"The \"+str(QC_model_name)+\" model will be evaluated\")\n","\n","else:\n","\n"," if os.path.exists(QC_model_path):\n"," model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n"," \n"," QC_model_folder = os.path.dirname(QC_model_path)\n"," Saving_path = QC_model_folder\n"," QC_model_name = os.path.basename(QC_model_folder)\n"," print(\"The \"+str(QC_model_name)+\" model will be evaluated\")\n"," \n"," else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","\n","\n","# Here we load the def that perform the QC, code taken from StarDist /~https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py\n","\n","\n","\n","matching_criteria = dict()\n","\n","def label_are_sequential(y):\n"," \"\"\" returns true if y has only sequential labels from 1... 
\"\"\"\n"," labels = np.unique(y)\n"," return (set(labels)-{0}) == set(range(1,1+labels.max()))\n","\n","\n","def is_array_of_integers(y):\n"," return isinstance(y,np.ndarray) and np.issubdtype(y.dtype, np.integer)\n","\n","\n","def _check_label_array(y, name=None, check_sequential=False):\n"," err = ValueError(\"{label} must be an array of {integers}.\".format(\n"," label = 'labels' if name is None else name,\n"," integers = ('sequential ' if check_sequential else '') + 'non-negative integers',\n"," ))\n"," is_array_of_integers(y) or print(\"An error occured\")\n"," if check_sequential:\n"," label_are_sequential(y) or print(\"An error occured\")\n"," else:\n"," y.min() >= 0 or print(\"An error occured\")\n"," return True\n","\n","\n","def label_overlap(x, y, check=True):\n"," if check:\n"," _check_label_array(x,'x',True)\n"," _check_label_array(y,'y',True)\n"," x.shape == y.shape or _raise(ValueError(\"x and y must have the same shape\"))\n"," return _label_overlap(x, y)\n","\n","@jit(nopython=True)\n","def _label_overlap(x, y):\n"," x = x.ravel()\n"," y = y.ravel()\n"," overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)\n"," for i in range(len(x)):\n"," overlap[x[i],y[i]] += 1\n"," return overlap\n","\n","\n","def intersection_over_union(overlap):\n"," _check_label_array(overlap,'overlap')\n"," if np.sum(overlap) == 0:\n"," return overlap\n"," n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)\n"," n_pixels_true = np.sum(overlap, axis=1, keepdims=True)\n"," return overlap / (n_pixels_pred + n_pixels_true - overlap)\n","\n","matching_criteria['iou'] = intersection_over_union\n","\n","\n","def intersection_over_true(overlap):\n"," _check_label_array(overlap,'overlap')\n"," if np.sum(overlap) == 0:\n"," return overlap\n"," n_pixels_true = np.sum(overlap, axis=1, keepdims=True)\n"," return overlap / n_pixels_true\n","\n","matching_criteria['iot'] = intersection_over_true\n","\n","\n","def intersection_over_pred(overlap):\n"," _check_label_array(overlap,'overlap')\n"," if np.sum(overlap) == 0:\n"," return overlap\n"," n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)\n"," return overlap / n_pixels_pred\n","\n","matching_criteria['iop'] = intersection_over_pred\n","\n","\n","def precision(tp,fp,fn):\n"," return tp/(tp+fp) if tp > 0 else 0\n","def recall(tp,fp,fn):\n"," return tp/(tp+fn) if tp > 0 else 0\n","def accuracy(tp,fp,fn):\n"," # also known as \"average precision\" (?)\n"," # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation\n"," return tp/(tp+fp+fn) if tp > 0 else 0\n","def f1(tp,fp,fn):\n"," # also known as \"dice coefficient\"\n"," return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0\n","\n","\n","def _safe_divide(x,y):\n"," return x/y if y>0 else 0.0\n","\n","def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):\n"," \"\"\"Calculate detection/instance segmentation metrics between ground truth and predicted label images.\n"," Currently, the following metrics are implemented:\n"," 'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'\n"," Corresponding objects of y_true and y_pred are counted as true positives (tp), false positives (fp), and false negatives (fn)\n"," whether their intersection over union (IoU) >= thresh (for criterion='iou', which can be changed)\n"," * mean_matched_score is the mean IoUs of matched true positives\n"," * mean_true_score is the mean IoUs of matched true positives but normalized by the total number of 
GT objects\n"," * panoptic_quality defined as in Eq. 1 of Kirillov et al. \"Panoptic Segmentation\", CVPR 2019\n"," Parameters\n"," ----------\n"," y_true: ndarray\n"," ground truth label image (integer valued)\n"," predicted label image (integer valued)\n"," thresh: float\n"," threshold for matching criterion (default 0.5)\n"," criterion: string\n"," matching criterion (default IoU)\n"," report_matches: bool\n"," if True, additionally calculate matched_pairs and matched_scores (note, that this returns even gt-pred pairs whose scores are below 'thresh')\n"," Returns\n"," -------\n"," Matching object with different metrics as attributes\n"," Examples\n"," --------\n"," >>> y_true = np.zeros((100,100), np.uint16)\n"," >>> y_true[10:20,10:20] = 1\n"," >>> y_pred = np.roll(y_true,5,axis = 0)\n"," >>> stats = matching(y_true, y_pred)\n"," >>> print(stats)\n"," Matching(criterion='iou', thresh=0.5, fp=1, tp=0, fn=1, precision=0, recall=0, accuracy=0, f1=0, n_true=1, n_pred=1, mean_true_score=0.0, mean_matched_score=0.0, panoptic_quality=0.0)\n"," \"\"\"\n"," _check_label_array(y_true,'y_true')\n"," _check_label_array(y_pred,'y_pred')\n"," y_true.shape == y_pred.shape or _raise(ValueError(\"y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes\".format(y_true=y_true, y_pred=y_pred)))\n"," criterion in matching_criteria or _raise(ValueError(\"Matching criterion '%s' not supported.\" % criterion))\n"," if thresh is None: thresh = 0\n"," thresh = float(thresh) if np.isscalar(thresh) else map(float,thresh)\n","\n"," y_true, _, map_rev_true = relabel_sequential(y_true)\n"," y_pred, _, map_rev_pred = relabel_sequential(y_pred)\n","\n"," overlap = label_overlap(y_true, y_pred, check=False)\n"," scores = matching_criteria[criterion](overlap)\n"," assert 0 <= np.min(scores) <= np.max(scores) <= 1\n","\n"," # ignoring background\n"," scores = scores[1:,1:]\n"," n_true, n_pred = scores.shape\n"," n_matched = min(n_true, n_pred)\n","\n"," def _single(thr):\n"," not_trivial = n_matched > 0 and np.any(scores >= thr)\n"," if not_trivial:\n"," # compute optimal matching with scores as tie-breaker\n"," costs = -(scores >= thr).astype(float) - scores / (2*n_matched)\n"," true_ind, pred_ind = linear_sum_assignment(costs)\n"," assert n_matched == len(true_ind) == len(pred_ind)\n"," match_ok = scores[true_ind,pred_ind] >= thr\n"," tp = np.count_nonzero(match_ok)\n"," else:\n"," tp = 0\n"," fp = n_pred - tp\n"," fn = n_true - tp\n"," # assert tp+fp == n_pred\n"," # assert tp+fn == n_true\n","\n"," # the score sum over all matched objects (tp)\n"," sum_matched_score = np.sum(scores[true_ind,pred_ind][match_ok]) if not_trivial else 0.0\n","\n"," # the score average over all matched objects (tp)\n"," mean_matched_score = _safe_divide(sum_matched_score, tp)\n"," # the score average over all gt/true objects\n"," mean_true_score = _safe_divide(sum_matched_score, n_true)\n"," panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)\n","\n"," stats_dict = dict (\n"," criterion = criterion,\n"," thresh = thr,\n"," fp = fp,\n"," tp = tp,\n"," fn = fn,\n"," precision = precision(tp,fp,fn),\n"," recall = recall(tp,fp,fn),\n"," accuracy = accuracy(tp,fp,fn),\n"," f1 = f1(tp,fp,fn),\n"," n_true = n_true,\n"," n_pred = n_pred,\n"," mean_true_score = mean_true_score,\n"," mean_matched_score = mean_matched_score,\n"," panoptic_quality = panoptic_quality,\n"," )\n"," if bool(report_matches):\n"," if not_trivial:\n"," stats_dict.update (\n"," # int() to be json serializable\n"," matched_pairs = 
tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)),\n"," matched_scores = tuple(scores[true_ind,pred_ind]),\n"," matched_tps = tuple(map(int,np.flatnonzero(match_ok))),\n"," )\n"," else:\n"," stats_dict.update (\n"," matched_pairs = (),\n"," matched_scores = (),\n"," matched_tps = (),\n"," )\n"," return namedtuple('Matching',stats_dict.keys())(*stats_dict.values())\n","\n"," return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh))\n","\n","\n","\n","def matching_dataset(y_true, y_pred, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):\n"," \"\"\"matching metrics for list of images, see `stardist.matching.matching`\n"," \"\"\"\n"," len(y_true) == len(y_pred) or _raise(ValueError(\"y_true and y_pred must have the same length.\"))\n"," return matching_dataset_lazy (\n"," tuple(zip(y_true,y_pred)), thresh=thresh, criterion=criterion, by_image=by_image, show_progress=show_progress, parallel=parallel,\n"," )\n","\n","\n","\n","def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):\n","\n"," expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'))\n","\n"," single_thresh = False\n"," if np.isscalar(thresh):\n"," single_thresh = True\n"," thresh = (thresh,)\n","\n"," tqdm_kwargs = {}\n"," tqdm_kwargs['disable'] = not bool(show_progress)\n"," if int(show_progress) > 1:\n"," tqdm_kwargs['total'] = int(show_progress)\n","\n"," # compute matching stats for every pair of label images\n"," if parallel:\n"," fn = lambda pair: matching(*pair, thresh=thresh, criterion=criterion, report_matches=False)\n"," with ThreadPoolExecutor() as pool:\n"," stats_all = tuple(pool.map(fn, tqdm(y_gen,**tqdm_kwargs)))\n"," else:\n"," stats_all = tuple (\n"," matching(y_t, y_p, thresh=thresh, criterion=criterion, report_matches=False)\n"," for y_t,y_p in tqdm(y_gen,**tqdm_kwargs)\n"," )\n","\n"," # accumulate results over all images for each threshold separately\n"," n_images, n_threshs = len(stats_all), len(thresh)\n"," accumulate = [{} for _ in range(n_threshs)]\n"," for stats in stats_all:\n"," for i,s in enumerate(stats):\n"," acc = accumulate[i]\n"," for k,v in s._asdict().items():\n"," if k == 'mean_true_score' and not bool(by_image):\n"," # convert mean_true_score to \"sum_matched_score\"\n"," acc[k] = acc.setdefault(k,0) + v * s.n_true\n"," else:\n"," try:\n"," acc[k] = acc.setdefault(k,0) + v\n"," except TypeError:\n"," pass\n","\n"," # normalize/compute 'precision', 'recall', 'accuracy', 'f1'\n"," for thr,acc in zip(thresh,accumulate):\n"," set(acc.keys()) == expected_keys or _raise(ValueError(\"unexpected keys\"))\n"," acc['criterion'] = criterion\n"," acc['thresh'] = thr\n"," acc['by_image'] = bool(by_image)\n"," if bool(by_image):\n"," for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):\n"," acc[k] /= n_images\n"," else:\n"," tp, fp, fn, n_true = acc['tp'], acc['fp'], acc['fn'], acc['n_true']\n"," sum_matched_score = acc['mean_true_score']\n","\n"," mean_matched_score = _safe_divide(sum_matched_score, tp)\n"," mean_true_score = _safe_divide(sum_matched_score, n_true)\n"," panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)\n","\n"," acc.update(\n"," precision = precision(tp,fp,fn),\n"," recall = recall(tp,fp,fn),\n"," accuracy = accuracy(tp,fp,fn),\n"," f1 = 
f1(tp,fp,fn),\n"," mean_true_score = mean_true_score,\n"," mean_matched_score = mean_matched_score,\n"," panoptic_quality = panoptic_quality,\n"," )\n","\n"," accumulate = tuple(namedtuple('DatasetMatching',acc.keys())(*acc.values()) for acc in accumulate)\n"," return accumulate[0] if single_thresh else accumulate\n","\n","\n","\n","# copied from scikit-image master for now (remove when part of a release)\n","def relabel_sequential(label_field, offset=1):\n"," \"\"\"Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}.\n"," This function also returns the forward map (mapping the original labels to\n"," the reduced labels) and the inverse map (mapping the reduced labels back\n"," to the original ones).\n"," Parameters\n"," ----------\n"," label_field : numpy array of int, arbitrary shape\n"," An array of labels, which must be non-negative integers.\n"," offset : int, optional\n"," The return labels will start at `offset`, which should be\n"," strictly positive.\n"," Returns\n"," -------\n"," relabeled : numpy array of int, same shape as `label_field`\n"," The input label field with labels mapped to\n"," {offset, ..., number_of_labels + offset - 1}.\n"," The data type will be the same as `label_field`, except when\n"," offset + number_of_labels causes overflow of the current data type.\n"," forward_map : numpy array of int, shape ``(label_field.max() + 1,)``\n"," The map from the original label space to the returned label\n"," space. Can be used to re-apply the same mapping. See examples\n"," for usage. The data type will be the same as `relabeled`.\n"," inverse_map : 1D numpy array of int, of length offset + number of labels\n"," The map from the new label space to the original space. This\n"," can be used to reconstruct the original label field from the\n"," relabeled one. The data type will be the same as `relabeled`.\n"," Notes\n"," -----\n"," The label 0 is assumed to denote the background and is never remapped.\n"," The forward map can be extremely big for some inputs, since its\n"," length is given by the maximum of the label field. 
However, in most\n"," situations, ``label_field.max()`` is much smaller than\n"," ``label_field.size``, and in these cases the forward map is\n"," guaranteed to be smaller than either the input or output images.\n"," Examples\n"," --------\n"," >>> from skimage.segmentation import relabel_sequential\n"," >>> label_field = np.array([1, 1, 5, 5, 8, 99, 42])\n"," >>> relab, fw, inv = relabel_sequential(label_field)\n"," >>> relab\n"," array([1, 1, 2, 2, 3, 5, 4])\n"," >>> fw\n"," array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n"," 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5])\n"," >>> inv\n"," array([ 0, 1, 5, 8, 42, 99])\n"," >>> (fw[label_field] == relab).all()\n"," True\n"," >>> (inv[relab] == label_field).all()\n"," True\n"," >>> relab, fw, inv = relabel_sequential(label_field, offset=5)\n"," >>> relab\n"," array([5, 5, 6, 6, 7, 9, 8])\n"," \"\"\"\n"," offset = int(offset)\n"," if offset <= 0:\n"," raise ValueError(\"Offset must be strictly positive.\")\n"," if np.min(label_field) < 0:\n"," raise ValueError(\"Cannot relabel array that contains negative values.\")\n"," max_label = int(label_field.max()) # Ensure max_label is an integer\n"," if not np.issubdtype(label_field.dtype, np.integer):\n"," new_type = np.min_scalar_type(max_label)\n"," label_field = label_field.astype(new_type)\n"," labels = np.unique(label_field)\n"," labels0 = labels[labels != 0]\n"," new_max_label = offset - 1 + len(labels0)\n"," new_labels0 = np.arange(offset, new_max_label + 1)\n"," output_type = label_field.dtype\n"," required_type = np.min_scalar_type(new_max_label)\n"," if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize:\n"," output_type = required_type\n"," forward_map = np.zeros(max_label + 1, dtype=output_type)\n"," forward_map[labels0] = new_labels0\n"," inverse_map = np.zeros(new_max_label + 1, dtype=output_type)\n"," inverse_map[offset:] = labels0\n"," relabeled = forward_map[label_field]\n"," return relabeled, forward_map, inverse_map\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Drt2bI08JFwc"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by looking at the training loss over training epochs. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","During training values should decrease before reaching a minimal value which does not decrease further even after more training.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. 
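\n","\n","For example, a rough numerical check of convergence on a list of per-epoch loss values (an illustrative sketch only; in this notebook the loss values are obtained from the trainer reports in the plotting cell below):\n","\n","```python\n","import numpy as np\n","\n","loss = [0.90, 0.45, 0.30, 0.24, 0.22, 0.21, 0.205, 0.203]  # made-up example values\n","earlier, recent = np.mean(loss[-6:-3]), np.mean(loss[-3:])\n","if (earlier - recent) / earlier < 0.05:\n","    print('The loss curve has flattened out - training has likely converged')\n","else:\n","    print('The loss is still decreasing - consider training for longer')\n","```\n","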
After this point no further training is required.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"0W3iPXWBuuDS","scrolled":false,"cellView":"form"},"source":["#@markdown ##Plot loss function\n","\n","# Stop the trainer if it exists\n","instance_exist = True\n","try:\n"," trainer = InteractiveTrainer.get_instance()\n","except:\n"," instance_exist = False\n","\n","if Use_the_current_trained_model and instance_exist:\n","\n"," trainer = InteractiveTrainer.get_instance()\n"," reports = trainer.get_reports()\n"," \n"," loss = [report['loss'] for report in reports]\n","\n"," plt.figure(figsize=(15,10))\n","\n"," plt.subplot(2,1,1)\n"," plt.plot(loss, label='Training loss')\n"," plt.title('Training loss vs. epoch number (linear scale)')\n"," plt.ylabel('Loss')\n"," plt.xlabel('Epoch number')\n","\n","\n"," plt.subplot(2,1,2)\n"," plt.semilogy(loss, label='Training loss')\n"," plt.title('Training loss vs. epoch number (log scale)')\n"," plt.ylabel('Loss')\n"," plt.xlabel('Epoch number')\n"," # plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'), bbox_inches='tight', pad_inches=0)\n"," plt.show()\n","\n","else:\n"," print(bcolors.WARNING+\"Loss curves can currently only be obtained from a currently trained model.\"+bcolors.NORMAL)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"YZMyEUreRsD_"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the `Source_QC_folder` and `Target_QC_folder` ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** (IoU) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","Here, the IoU is both calculated over the whole image and on a per-object basis. The value displayed below is the IoU value calculated over the entire image. The IoU value calculated on a per-object basis is used to calculate the other metrics displayed.\n","\n","“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. \n","\n","When a segmented object has an IoU value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. The number of “**true positives**” is available in the table below. The number of “false positive” is then defined as “**false positive**” = “n_pred” - “true positive”. The number of “false negative” is defined as “false negative” = “n_true” - “true positive”.\n","\n","The mean_matched_score is the mean IoUs of matched true positives. The mean_true_score is the mean IoUs of matched true positives but normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 
2019](https://arxiv.org/abs/1801.00868).\n","\n","For more information about the other metric displayed, please consult the SI of the paper describing ZeroCostDL4Mic.\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your `model_folder`."]},{"cell_type":"code","metadata":{"id":"b8gyZwoERt58","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","\n","#@markdown ### Segmentation parameters:\n","Object_diameter = 0#@param {type:\"number\"}\n","\n","Flow_threshold = 0.4 #@param {type:\"slider\", min:0.1, max:1.1, step:0.1}\n","Cell_probability_threshold=0 #@param {type:\"slider\", min:-6, max:6, step:1}\n","\n","if Object_diameter is 0:\n"," Object_diameter = None\n"," print(\"The cell size will be estimated automatically for each image\")\n","\n","\n","# Find the number of channel in the input image\n","\n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = io.imread(Source_QC_folder+\"/\"+random_choice)\n","n_channel = 1 if x.ndim == 2 else x.shape[-1]\n","\n","\n","channels=[0,0]\n","QC_model_folder = os.path.join(Model_folder,Model_name)\n","QC_model_path = os.path.join(QC_model_folder, 'final')\n","QC_model_name = Model_name\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_folder+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_folder+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_folder+\"/Quality Control/Prediction\"):\n"," rmtree(QC_model_folder+\"/Quality Control/Prediction\")\n","os.makedirs(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","\n","model = models.CellposeModel(gpu=True, pretrained_model=QC_model_path, torch=True, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n","\n","\n","# Here we need to make predictions\n","\n","for name in os.listdir(Source_QC_folder):\n"," \n"," print(\"Performing prediction on: \"+name)\n"," image = io.imread(Source_QC_folder+\"/\"+name) \n","\n"," short_name = os.path.splitext(name)\n"," masks, flows, styles = model.eval(image, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," \n"," os.chdir(QC_model_folder+\"/Quality Control/Prediction\")\n"," imsave(str(short_name[0])+\".tif\", masks, compress=ZIP_DEFLATED) \n"," \n","# Here we start testing the differences between GT and predicted masks\n","\n","with open(QC_model_folder+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file, delimiter=\",\")\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\", \"false positive\", \"true positive\", \"false negative\", \"precision\", \"recall\", \"accuracy\", \"f1 score\", \"n_true\", \"n_pred\", \"mean_true_score\", \"mean_matched_score\", \"panoptic_quality\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_folder+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," # Calculate the matching (with IoU threshold `thresh`) and all metrics\n","\n"," stats = matching(test_ground_truth_image, test_prediction, thresh=0.5)\n"," \n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])\n","\n","\n","\n","df = pd.read_csv (QC_model_folder+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\")\n","print(tabulate(df, headers='keys', tablefmt='psql'))\n","\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Source_QC_folder)):\n"," \n","\n"," plt.figure(figsize=(25,5))\n"," if n_channel > 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file))\n"," if n_channel == 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n","\n"," target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_folder+\"/Quality Control/Prediction/\"+file, as_gray = True)\n","\n"," stats = matching(prediction, target_image, thresh=0.5)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," if n_channel > 1:\n"," plt.imshow(source_image)\n"," if n_channel == 1:\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground 
Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));\n"," plt.savefig(QC_model_folder+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","# full_QC_model_path = QC_model_folder+'/'\n","# qc_pdf_export()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JqAGDvAI7ALU"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"Lf9YM22iIlzv"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the model's name and path to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder_prediction`:** This folder should contain the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder is where the results from the predictions will be saved.\n","\n","**`Flow_threshold`:** This parameter controls the maximum allowed error of the flows for each mask. Increase this threshold if cellpose is not returning as many masks as you'd expect. Similarly, decrease this threshold if cellpose is returning too many ill-shaped masks. **Default value: 0.4**\n","\n","**`Cell_probability_threshold`:** The pixels greater than the Cell_probability_threshold are used to run dynamics and determine masks. Decrease this threshold if cellpose is not returning as many masks as you'd expect. Similarly, increase this threshold if cellpose is returning too many masks, particularly from dim areas. 
**Default value: 0.0**"]},{"cell_type":"code","metadata":{"id":"Yj8M-RD37RQA","cellView":"form"},"source":["# -------------------------------------------------- \n","#@markdown ###Data parameters\n","Data_folder_prediction = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Model parameters\n","#@markdown Do you want to use the model you just trained?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","#@markdown Otherwise, please provide path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Segmentation parameters:\n","Object_diameter = 0#@param {type:\"number\"}\n","Flow_threshold = 0.4 #@param {type:\"slider\", min:0.1, max:1.1, step:0.1}\n","Cell_probability_threshold=0 #@param {type:\"slider\", min:-6, max:6, step:1}\n","\n","# -------------------------------------------------- \n","# TODO: allow to run on other file formats\n","# prediction_image_list = glob.glob(Data_folder_prediction+\"/*.jpg\")\n","# n_files = len(prediction_image_list)\n","\n","# filenames_list = []\n","# for name in os.listdir(Data_folder_prediction):\n","# if os.path.isfile(name):\n","# filenames_list.append(os.path.splitext(name)[0])\n","\n","# filenames_list = []\n","# for name in prediction_image_list:\n","# filenames_list.append(os.path.splitext(os.path.basename(name))[0])\n","\n","extension_list = ['*.jpg', '*.tif', '*.png']\n","prediction_image_list, filenames_list = get_image_list(Data_folder_prediction, extension_list)\n","\n","n_files = len(prediction_image_list)\n","print(\"Total number of files: \"+str(n_files))\n","\n","if Use_the_current_trained_model:\n"," prediction_model_path = os.path.join(Model_folder, Model_name)\n","\n","\n","model_path = os.path.join(prediction_model_path, \"final\")\n","# TODO: Check the line below for file compatibility\n","channels=[0,0] \n","model = models.CellposeModel(gpu=True, pretrained_model=model_path, torch=False, diam_mean=30.0, net_avg=True, device=None, residual_on=True, style_on=True, concatenation=False)\n","\n","\n","if (Object_diameter == 0):\n"," Object_diameter = None\n","\n","masks_list = []\n","for i in tqdm(range(n_files)):\n"," img = io.imread(prediction_image_list[i])\n"," masks, flows, styles = model.eval(img, diameter=Object_diameter, flow_threshold=Flow_threshold,cellprob_threshold=Cell_probability_threshold, channels=channels)\n"," imsave(os.path.join(Result_folder, filenames_list[i]+'.tif'), masks)\n","\n","\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_labels(filename = filenames_list):\n"," plt.figure(figsize=(13,10))\n","\n"," img = io.imread(os.path.join(Data_folder_prediction, filename))\n"," mask = io.imread(os.path.join(Result_folder, filename+'.tif'))\n","\n"," plt.subplot(121)\n"," plt.imshow(img, cmap='gray', interpolation='nearest')\n"," plt.title('Source image')\n"," plt.axis('off')\n"," plt.subplot(122)\n"," plt.imshow(mask, cmap='nipy_spectral', interpolation='nearest')\n"," plt.title('Label')\n"," plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"D0I9oX82QDk6"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"w4sH7tKyKQ2z"},"source":["# **7. Version log**\n","\n","---\n","**v1.13**: \n","\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*. \n","* This version also now includes built-in version check and the version log that you're reading now. \n","* This version also specifically pulls StarDist packages version 0.6.2, due to current incompatibilities with the newest versions.\n","* Better data input compatibilities.\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"2nfGReX7QJnE"},"source":["---\n","#**Thank you for using Interactive segmentation - Cellpose 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Beta notebooks/fnet_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/fnet_2D_ZeroCostDL4Mic.ipynb new file mode 100644 index 00000000..2af8e180 --- /dev/null +++ b/Colab_notebooks/Beta notebooks/fnet_2D_ZeroCostDL4Mic.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fnet_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"12UsRdIQbcWQjYewI2wrcwIWfVxc6hOfc","timestamp":1618500199387},{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611063104553},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developped to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: /~https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on user data, upload the data in the following format to your google drive. 
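\n","\n","As a quick sanity check before training, the signal and target folders can be compared to confirm that every signal image has a target image with exactly the same file name (an illustrative sketch; the folder paths below are placeholders for your own dataset folders, organised as described next):\n","\n","```python\n","import os\n","\n","signal_dir = '/content/gdrive/My Drive/Experiment_A/Training/signal'  # placeholder path\n","target_dir = '/content/gdrive/My Drive/Experiment_A/Training/target'  # placeholder path\n","\n","signal_files = set(os.listdir(signal_dir))\n","target_files = set(os.listdir(target_dir))\n","print('Signal images without a matching target:', sorted(signal_files - target_files))\n","print('Target images without a matching signal:', sorted(target_files - signal_files))\n","```\n","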
To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["#**2. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","!pip install fpdf\n","\n","#@markdown ##Play this cell to download fnet to your drive. 
If it is already installed this will only install the fnet dependencies.\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import sys\n","import numpy as np\n","import shutil\n","import os\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","from datetime import datetime\n","import time\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","#clone fnet from github to colab\n","#!pip install -U scipy==1.2.0\n","#!pip install matplotlib==2.2.3\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet'):\n"," !git clone -b release_1 --single-branch /~https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n"," shutil.move('/content/pytorch_fnet','/content/gdrive/My Drive/pytorch_fnet')\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","#from skimage.util import img_as_uint\n","import matplotlib as mpl\n","#from scipy import signal\n","#from scipy import ndimage\n","\n","\n","#This function replaces the old default files with new values\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def insert_line_to_file(filepath,line_number,insertion):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def add_validation(filepath,line_number,insert,append):\n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not 'PATH_DATASET_VAL_CSV=' in f.read():\n"," contents.insert(line_number, insert)\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else 
mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","#2D \n","\n","replace(\"/content/gdrive/MyDrive/pytorch_fnet/train_model.py\",\"default=[32, 64, 64]\",\"default=[128, 128]\")\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--nn_module', default='fnet_nn_3d'\",\"'--nn_module', default='fnet_nn_2d'\")\n","\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\", default_resizer_str]\",\"]\")\n","#replace(\"/content/gdrive/My Drive/pytorch_fnet/predict.py\",\", default_resizer_str]\",\"]\")\n","\n","\n","print(\"-------------------\")\n","print(\"Libraries installed\")\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free Prediction (fnet)'\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The GPU used was a '+gpu_name+'.'\n","\n"," #if Use_pretrained_model:\n"," # text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), pytorch (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'. The model was trained from the pretrained model: '+pretrained_model_path+'.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
percentage_validation{0}
steps{1}
batch_size{2}
\n"," \"\"\".format(percentage_validation,steps,batch_size)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Fnet.png').shape\n"," pdf.image('/content/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," if trained:\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," else:\n"," pdf.output('/content/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," \n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free prediction (fnet)'\n"," #model_name = os.path.basename(QC_model_folder)\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," 
pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," PSNR_PvsGT = header[4]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,NRMSE_PvsGT,PSNR_PvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," PSNR_PvsGT = row[4]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(PSNR_PvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}
{0}{1}{2}{3}{4}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`percentage validation`** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. 
**Default: 4**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Datasets\n","#Datasets\n","from astropy.visualization import simple_norm\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","#replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","#dataset = model_name #The name of the dataset and the model will be the same\n","\n","#Here, we check if the dataset already exists. 
If not, copy the dataset from google drive to the data folder\n"," \n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name):\n"," #shutil.copytree(own_dataset,'/content/gdrive/My Drive/pytorch_fnet/data/'+dataset)\n"," os.makedirs('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name) and not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","elif os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name) and os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n","#Create a path_csv file to point to the training images\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#print(\"Selected \"+dataset+\" as training set\")\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Here we define the random set of training files to be used for validation\n","val_files = random.sample(source,round(len(source)*(percentage_validation/100)))\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My 
Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_files:\n"," shutil.move('./'+model_name+'/'+source_name+'/'+file,'./'+model_name+'/Validation_Input/'+file)\n"," shutil.move('./'+model_name+'/'+target_name+'/'+file,'./'+model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+model_name+'/'+source_name)\n","target = os.listdir('./'+model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+model_name+'/Validation_Input')\n","val_target = os.listdir('./'+model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","#Finally, we create a validation csv file to construct the validation dataset\n","with open(model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_signal)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Input/\"+val_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Target/\"+val_target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 16#@param {type:\"number\"}\n","number_of_images = len(source)\n","\n","#3. Insert the above values into train_model.sh\n","!if ! 
grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh\n","\n","#If new parameters are inserted here for training a model with the same name\n","#the previous training csv needs to be removed, to prevent the model using the old training split or paths.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name)\n","\n","#No Augmentation by default\n","Use_Data_augmentation = False\n","\n","#Load one randomly chosen training source file\n","random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","# Image_Z = x.shape[0]\n","# mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Training_target)\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 
20-30 images)."]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["from skimage import io\n","import numpy as np\n","\n","Use_Data_augmentation = True #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = True #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target', flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(0,1))\n"," source_img_180 = np.rot90(source_img_90,axes=(0,1))\n"," source_img_270 = np.rot90(source_img_180,axes=(0,1))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(0,1))\n"," target_img_180 = np.rot90(target_img_90,axes=(0,1))\n"," target_img_270 = np.rot90(target_img_180,axes=(0,1))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," 
io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target'):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n"," \n"," if os.path.exists(Saving_path+'/augmented_validation_source'):\n"," shutil.rmtree(Saving_path+'/augmented_validation_source') \n"," os.mkdir(Saving_path+'/augmented_validation_source')\n"," \n"," if os.path.exists(Saving_path+'/augmented_validation_target'):\n"," shutil.rmtree(Saving_path+'/augmented_validation_target') \n"," os.mkdir(Saving_path+'/augmented_validation_target')\n"," \n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name,flip=Flip)\n"," rotation_aug('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input','/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target', aug_source_dest='augmented_validation_source', aug_target_dest='augmented_validation_target', flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," flip('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input','/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target', aug_source_dest='augmented_validation_source', aug_target_dest='augmented_validation_target')\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," 
source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the target folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," #Fetch the path and extract the name of the Validation source folder\n"," Validation_source = Saving_path+'/augmented_validation_source'\n"," Validation_target = Saving_path+'/augmented_validation_target'\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n"," shutil.copytree(Validation_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n"," shutil.copytree(Validation_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","\n"," os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," #Redefine the source and target lists after moving the validation files\n"," source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n"," target = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","\n"," #Define Validation file lists\n"," val_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n"," val_target = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_signal)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Input/\"+val_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/Validation_Target/\"+val_target[i]])\n","\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv')\n","\n"," if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n"," with open(model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'.csv')\n","\n"," #Here, we ensure that the all files, including Validation are saved somewhere together for later access, e.g. 
for retraining.\n"," for image in os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input'):\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input',image),Saving_path+'/augmented_source/'+image)\n"," shutil.copyfile(os.path.join('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target',image),Saving_path+'/augmented_target/'+image)\n"," \n"," #Here, we ensure that there aren't too many images in the buffer.\n"," #The best value will depend on the size of the images and the assigned GPU.\n"," #If too many images are loaded to the buffer the notebook will terminate the training as the RAM limit will be exceeded.\n"," if len(source)>400:\n"," number_of_images = 400\n"," else:\n"," number_of_images = len(source)\n","\n"," os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Nyf9ndiS7sL9"},"source":["#**4. Train the network**\n","---\n","Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training on an already pre-trained model choose this section\n","\n"," **Carefully read the options before starting training.**"]},{"cell_type":"markdown","metadata":{"id":"P9OJ0nlI71Rc"},"source":["##**4.1. Start Training**\n","---\n","\n","####Play the cell below to start training. \n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model save it before playing the cell below or give your model a different name (section 3).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"id":"X8YHeSGr76je","cellView":"form"},"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. 
If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n","os.chdir(\"/content/gdrive/My Drive/pytorch_fnet/scripts\")\n","number_of_images = 50 #@param{type:\"number\"}\n","!chmod u+x train_model.sh\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"7Ofm-71T8ABX"},"source":["#@markdown ##Start training\n","#pdf_export(augmentation = Use_Data_augmentation)\n","start = time.time()\n","\n","#Overwriting old models and saving them separately if True\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name)\n","\n","#This tifffile release runs error-free in this version of fnet.\n","!pip install tifffile==2019.7.26\n","\n","#Here we import an additional module to the functions.py file to run it without errors.\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," \n","print('Let''s start the training!')\n","#Here we start the training\n","!./scripts/train_model.sh $model_name 0\n","\n","#After training overwrite any existing model in the model_path with the new trained model.\n","# if os.path.exists(model_path+'/'+model_name):\n","# shutil.rmtree(model_path+'/'+model_name)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+model_name+'_val.csv',model_path+'/'+model_name+'/'+model_name+'_val.csv')\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/'+target_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Input')\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+model_name+'/Validation_Target')\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","#pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bOdyjxWV8IrO"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. 
**If you did not time out you can ignore this section.**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"aWJxOy-R8OhH"},"source":["#@markdown ##Play this cell if your model training timed out and indicate where you want to save the last checkpoint.\n","\n","import shutil\n","import os\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name):\n"," shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,model_path+'/'+model_name)\n","else:\n"," print('This model name does not exist in your saved_models folder. Make sure you have entered the name of the model that timed out.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-JxxMmVr8Tw-"},"source":["##**4.2. Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"iDIgosht8U7F"},"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we replace values in the old files\n","\n","insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n","append = '\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}'\n","\n","add_validation(\"/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh\",10,insert,append)\n","#Clear the White space from train.sh\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/gdrive/My Drive/pytorch_fnet/scripts/train_model_temp.sh','/content/gdrive/My Drive/pytorch_fnet/scripts/train_model.sh')\n","\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/gdrive/My Drive/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","#model_name = \"\" #@param {type:\"string\"}\n","\n","Pretrained_model_name = 
os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","batch_size = 4 #@param {type:\"number\"}\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#Move your model to fnet\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name):\n"," shutil.copytree(Pretrained_model_folder,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name)\n","\n","#Move the datasets into fnet\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name)\n","shutil.copytree(Training_source,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.copytree(Training_target,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","### number_of_images = len(os.listdir(Training_source)) ###\n","\n","#Change the train_model.sh file to include chosen dataset\n","!chmod u+x ./train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","\n","# We will use the same validation files from the training dataset as used before,\n","# This makes sure that the model is not validated with files it has seen in training before saving.\n","\n","#First we get the names of the validation files from the previous training which are saved in the validation csv.\n","val_source_list = []\n","\n","##CHECK THIS Prediction_model_name\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv'):\n"," shutil.copyfile(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","with open('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv', 'r') as f:\n","#with open(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv', 'r') as f:\n"," contents = csv.reader(f,delimiter=',')\n"," for row in contents:\n"," val_source_list.append(row[0])\n","\n","#Get the file list without the header\n","val_source_list = val_source_list[1::]\n","\n","#Get only the file names and not the full path\n","for i in range(0,len(val_source_list)):\n"," val_source_list[i] = os.path.basename(os.path.normpath(val_source_list[i]))\n","\n","source = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","\n","#Make validation directories\n","os.mkdir('/content/gdrive/My 
Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input')\n","os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target')\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","\n","#Move a random set of files from the training to the validation folders\n","for file in val_source_list:\n"," #os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Input/'+file)\n"," shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name+'/'+file,'/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/Validation_Target/'+file)\n","\n","#Redefine the source and target lists after moving the validation files\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","\n","#Define Validation file lists\n","val_signal = os.listdir('./'+Pretrained_model_name+'/Validation_Input')\n","val_target = os.listdir('./'+Pretrained_model_name+'/Validation_Target')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","shutil.copyfile(Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv')\n","\n","#Make a training csv file.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name)\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data')\n","source = os.listdir('./'+Pretrained_model_name+'/'+source_name)\n","target = os.listdir('./'+Pretrained_model_name+'/'+target_name)\n","with open(Pretrained_model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source)):\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+source_name+\"/\"+source[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Pretrained_model_name+\"/\"+target_name+\"/\"+target[i]])\n","\n","shutil.move('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'.csv','/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'.csv')\n","\n","#Find the number of previous training iterations (steps) from loss csv file\n","\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 10000#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. 
Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/scripts')\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh\n","\n","#Load one randomly chosen training source file\n","random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Training_target)\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"5IXdFqhM8gO2"},"source":["start = time.time()\n","\n","#@markdown ##4.2. Start re-training model\n","!pip install tifffile==2019.7.26\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/fnet')\n","\n","insert_line_to_file(\"/content/gdrive/My Drive/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\")\n","\n","#Here we retrain the model on the chosen dataset.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x ./scripts/train_model.sh\n","!./scripts/train_model.sh $Pretrained_model_name 0\n","\n","if os.path.exists(Pretrained_model_folder):\n"," shutil.rmtree(Pretrained_model_folder)\n","shutil.copytree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Pretrained_model_name,Pretrained_model_folder)\n","\n","#Get rid of duplicates of training data in pytorch_fnet after training completes\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+source_name)\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+Pretrained_model_name+'/'+target_name)\n","\n","shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Pretrained_model_name+'_val.csv',Pretrained_model_folder+'/'+Pretrained_model_name+'_val.csv')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","min, sec = divmod(dt, 60) \n","hour, min = divmod(min, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",min,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","from datetime import datetime\n","\n","class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n","pdf = MyFPDF()\n","pdf.add_page()\n","pdf.set_right_margin(-1)\n","pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","Network = 'Label-free Prediction (fnet)'\n","day = datetime.now()\n","date_time = str(day)[0:10]\n","\n","Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n","pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n","# add another cell \n","training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n","pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n","pdf.ln(1)\n","\n","Header_2 = 
'Information for your materials and methods:'\n","pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n","all_packages = ''\n","for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","#print(all_packages)\n","\n","#Main Packages\n","main_packages = ''\n","version_numbers = []\n","for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n","cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n","cuda_version = cuda_version.stdout.decode('utf-8')\n","cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n","gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n","gpu_name = gpu_name.stdout.decode('utf-8')\n","gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","#print(cuda_version[cuda_version.find(', V')+3:-1])\n","#print(gpu_name)\n","\n","shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n","dataset_size = len(os.listdir(Training_source))\n","\n","text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","pdf.multi_cell(190, 5, txt = text, align='L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.ln(1)\n","pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n","pdf.set_font('')\n","if Use_Data_augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n","else:\n"," aug_text = 'No augmentation was used for training.'\n","pdf.multi_cell(190, 5, txt=aug_text, align='L')\n","pdf.set_font('Arial', size = 11, style = 'B')\n","pdf.ln(1)\n","pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font_size(10.)\n","# if Use_Default_Advanced_Parameters:\n","# pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n","pdf.cell(200, 5, txt='The following parameters were used for training:')\n","pdf.ln(1)\n","html = \"\"\" \n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
percentage_validation{0}
steps{1}
batch_size{2}
\n","\"\"\".format(percentage_validation,steps,batch_size)\n","pdf.write_html(html)\n","\n","#pdf.multi_cell(190, 5, txt = text_2, align='L')\n","pdf.set_font(\"Arial\", size = 11, style='B')\n","pdf.ln(1)\n","pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n","#pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n","pdf.ln(1)\n","pdf.set_font('')\n","pdf.set_font('Arial', size = 10, style = 'B')\n","pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n","pdf.set_font('')\n","pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","pdf.ln(1)\n","pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n","pdf.ln(1)\n","exp_size = io.imread(model_path+'/TrainingDataExample_Fnet.png').shape\n","pdf.image(model_path+'/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n","pdf.ln(1)\n","ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n","pdf.multi_cell(190, 5, txt = ref_1, align='L')\n","ref_2 = '- Label-free Prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n","pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","pdf.ln(3)\n","reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n","pdf.set_font('Arial', size = 11, style='B')\n","pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n","pdf.output(Prediction_model_folder+'/'+Prediction_model_name+'_'+date_time+\"_training_report.pdf\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. 
In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. 
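As an illustration only, re-using the hypothetical gt and pred arrays (already normalised to [0, 1]) from the sketch at the start of this section, the RSE map and the NRMSE score reported by the cell below boil down to:

import numpy as np
from matplotlib import pyplot as plt

rse_map = np.sqrt(np.square(gt - pred))   # per-pixel root squared error; near 0 (dark) means good agreement
nrmse = np.sqrt(np.mean(rse_map))         # NRMSE as computed in the QC cell below
plt.imshow(rse_map, cmap='CMRmap', vmin=0, vmax=1)
plt.colorbar()
plt.title('RSE map: Target vs. Prediction (NRMSE: %.3f)' % nrmse)
plt.show()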
Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD"},"source":["#Overwrite results folder if it already exists at the given location\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","from distutils.dir_util import copy_tree\n","\n","#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+Predictions_name)\n","\n","if Use_the_current_trained_model == True:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(QC_model_path+'/'+QC_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = QC_model_name\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict_2d.sh\n","!sed -i \"1,21!d\" /content/gdrive/MyDrive/pytorch_fnet/scripts/predict_2d.sh\n","\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict_2d.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","#!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","\n","#Check that we are using .tif files\n","# file_list = os.listdir(Source_QC_folder)\n","# text = file_list[0]\n","\n","# if text.endswith('.tif') or text.endswith('.tiff'):\n","# !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n","# !if ! 
grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n","# !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+source_dataset_name)\n"," # shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name+'/'+target_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name):\n"," shutil.copytree(Source_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+source_dataset_name)\n"," # shutil.copytree(Target_QC_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+target_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+source_dataset_name)\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," # if Use_the_current_trained_model == True:\n"," # writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n"," # # This currently assumes that the names are identical for source and target: see \"test_target\" variable is never used\n"," # else:\n"," # writer.writerow([\"/content/gdrive/My 
Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+target_dataset_name+\"/\"+test_signal[i]])\n"," if Use_the_current_trained_model ==True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+QC_model_name+\"/\"+source_dataset_name+\"/\"+test_signal[i]])\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+source_dataset_name+\"/\"+test_signal[i]])\n","#We run the predictions\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!scripts/predict_2d.sh $Predictions_name 0\n","\n","#Save the results\n","QC_results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/'+Predictions_name+'/test')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/Quality Control/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/Quality Control/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/'+Predictions_name+'/test/'+QC_results_files[i]+'/prediction_'+Predictions_name+'.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Prediction/'+'Predicted_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/'+Predictions_name+'/test/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Signal/'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/'+Predictions_name+'/test/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/Quality Control/Target/'+test_signal[i])\n","\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/data/'+QC_model_name)\n","\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle slice\n","img = io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," if len(test_GT_stack.shape) > 2:\n"," test_GT_stack = test_GT_stack.squeeze()\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," #test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," #n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((test_GT_stack.shape[0], test_GT_stack.shape[1]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((test_GT_stack.shape[0], test_GT_stack.shape[1]))\n","\n"," #for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack, test_prediction_stack, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, np.squeeze(test_prediction_norm,0), data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,np.squeeze(test_prediction_norm,0),data_range=1.0)\n","\n","\n"," writer.writerow([thisFile,str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," #slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," #if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," 
PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","#if len(img_GT.shape) > 2:\n"," # img_GT = img_GT\n","plt.imshow(img_GT)\n","plt.title('Target')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source,aspect='equal',cmap=cmap)\n","plt.title('Source')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(np.squeeze(img_RSE_GTvsPrediction,0), cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","#qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. 
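In addition to the options mentioned just below, one purely illustrative workaround (hypothetical folder paths; assumes 2D .tif files) is to split oversized images into smaller tiles before pointing Data_folder at them:

import os
from skimage import io

tile = 512                                                    # hypothetical tile size in pixels
src, dst = '/content/large_images', '/content/tiled_images'  # hypothetical folders
os.makedirs(dst, exist_ok=True)
for fname in os.listdir(src):
    img = io.imread(os.path.join(src, fname))
    base = os.path.splitext(fname)[0]
    for y in range(0, img.shape[0], tile):
        for x in range(0, img.shape[1], tile):
            crop = img[y:y + tile, x:x + tile]
            io.imsave(os.path.join(dst, '%s_y%d_x%d.tif' % (base, y, x)), crop)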
To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/results'):\n"," shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","Predictions_name = 'TempPredictionFolder'\n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(Results_folder+'/'+Predictions_name):\n"," shutil.rmtree(Results_folder+'/'+Predictions_name)\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","if Use_the_current_trained_model:\n"," #Move the contents of the saved_models folder from your training to the new folder\n"," #Here, we use a different copyfunction as we only need the contents of the trained_model folder\n"," copy_tree('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n","else:\n"," copy_tree(Prediction_model_path+'/'+Prediction_model_name,'/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name)\n"," #dataset = Prediction_model_name\n","\n","# Get the name of the folder the test data is in\n","test_dataset_name = os.path.basename(os.path.normpath(Data_folder))\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","# We also allow the maximum number of images to be processed to be higher, i.e. 
1000.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"s/N_IMAGES=.*/N_IMAGES=1000/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict_2d.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict_2d.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh\n","!sed -i \"1,21!d\" /content/gdrive/MyDrive/pytorch_fnet/scripts/predict_2d.sh\n","\n","#Check that we are using .tif files\n","file_list = os.listdir(Data_folder)\n","text = file_list[0]\n","\n","if text.endswith('.tif') or text.endswith('.tiff'):\n"," !chmod u+x /content/gdrive/My\\ Drive/pytorch_fnet//scripts/predict.sh\n"," !if ! grep class_dataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi\n"," !if grep CziDataset /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/gdrive/My\\ Drive/pytorch_fnet/scripts/predict.sh; fi \n","\n","#Create test_data folder in pytorch_fnet\n","\n","# If your test data is not in the pytorch_fnet data folder it needs to be copied there.\n","if Use_the_current_trained_model == True:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Prediction_model_name+'/'+test_dataset_name)\n","else:\n"," if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name):\n"," shutil.copytree(Data_folder,'/content/gdrive/My Drive/pytorch_fnet/data/'+Predictions_name+'/'+test_dataset_name)\n","\n","\n","# Make a folder that will hold the test.csv file in your new folder\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs')\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name)\n","\n","\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/')\n","\n","#Make a new folder in saved_models to use the trained model for inference.\n","if not os.path.exists('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name):\n"," os.mkdir('/content/gdrive/My Drive/pytorch_fnet/saved_models/'+Predictions_name) \n","\n","\n","#Get file list from the folders containing the files you want to use for inference.\n","#test_signal = os.listdir('/content/gdrive/My Drive/pytorch_fnet/data/'+dataset+'/'+test_dataset_name)\n","test_signal = os.listdir(Data_folder)\n","\n","#Now we make a path csv file to point the predict.sh file to the correct paths for the inference files.\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/')\n","\n","#If an old test csv exists we want to overwrite it, so we can insert new test data.\n","if os.path.exists('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv'):\n"," os.remove('/content/gdrive/My Drive/pytorch_fnet/data/csvs/'+Predictions_name+'/test.csv')\n","\n","#Here we create a new test.csv\n","with open('test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," 
writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," if Use_the_current_trained_model ==True:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Prediction_model_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n"," else:\n"," writer.writerow([\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i],\"/content/gdrive/My Drive/pytorch_fnet/data/\"+Predictions_name+\"/\"+test_dataset_name+\"/\"+test_signal[i]])\n","\n","#We run the predictions\n","start = time.time()\n","os.chdir('/content/gdrive/My Drive/pytorch_fnet/')\n","!chmod u+x scripts/predict_2d.sh\n","!scripts/predict_2d.sh $Predictions_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Save the results\n","results_files = os.listdir('/content/gdrive/My Drive/pytorch_fnet/results/'+Predictions_name+'/test')\n","for i in range(len(results_files)-2):\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/'+Predictions_name+'/test/'+results_files[i]+'/prediction_'+Predictions_name+'.tiff', Results_folder+'/'+'Prediction_'+test_signal[i])\n"," shutil.copyfile('/content/gdrive/My Drive/pytorch_fnet/results/'+Predictions_name+'/test/'+results_files[i]+'/signal.tiff', Results_folder+'/'+test_signal[i])\n","\n","#Comment this out if you want to see the total original results from the prediction in the pytorch_fnet folder.\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet/results')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bFtArIjs9tS9"},"source":["##**6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. Select the slice of the slice you want to visualize."]},{"cell_type":"code","metadata":{"id":"66-af3rO9vM4","cellView":"form"},"source":["!pip install matplotlib==2.2.3\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from skimage import io\n","import os\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#@markdown ###Select the image would you like to view below\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," os.chdir(Results_folder)\n","\n","#source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Prediction_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n","#Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image, cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image, cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Z-8uu6aEHrdd"},"source":["for myfile in os.listdir('')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"89tlSWBC940z"},"source":["## **6.3. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"RoiTamQC9_Pr"},"source":["## **6.4. Purge unnecessary folders**\n","---\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"3VStzQ0k-FUm"},"source":["#@markdown ##If you have checked that all your data is saved you can delete the pytorch_fnet folder from your drive by playing this cell.\n","\n","import shutil\n","shutil.rmtree('/content/gdrive/My Drive/pytorch_fnet')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using fnet!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb index 2a34f74e..d96bccb4 100644 --- a/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CARE_2D_ZeroCostDL4Mic.ipynb @@ -1,2039 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "CARE_2D_ZeroCostDL4Mic.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "V9zNGvape2-I" - }, - "source": [ - "# **CARE: Content-aware image restoration (2D)**\n", - "\n", - "---\n", - "\n", - "CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n", - "\n", - " **This particular notebook enables restoration of 2D datasets. If you are interested in restoring a 3D dataset, you should use the CARE 3D notebook instead.**\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is based on the following paper: \n", - "\n", - "**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n", - "\n", - "And source code found in: /~https://github.com/csbdeep/csbdeep\n", - "\n", - "For a more in-depth description of the features of the network, please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n", - "\n", - "We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n", - "\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use our notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. 
This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vNMDQHm0Ah-Z" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - " For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n", - "\n", - " Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n", - "\n", - "**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n", - "\n", - " **Additionally, the corresponding input and output files need to have the same name**.\n", - "\n", - " Please note that you currently can **only use .tif files!**\n", - "\n", - "\n", - "Here's a common data structure that can work:\n", - "* Experiment A\n", - " - **Training dataset**\n", - " - Low SNR images (Training_source)\n", - " - img_1.tif, img_2.tif, ...\n", - " - High SNR images (Training_target)\n", - " - img_1.tif, img_2.tif, ...\n", - " - **Quality control dataset**\n", - " - Low SNR images\n", - " - img_1.tif, img_2.tif\n", - " - High SNR images\n", - " - img_1.tif, img_2.tif\n", - " - **Data to be predicted**\n", - " - **Results**\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4-r1gE7Iamv" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. 
Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "BDhmUgqCStlm", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "\n", - "%tensorflow_version 1.x\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-oqBTeLaImnU" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "01Djr8v-5pPk", - "cellView": "form" - }, - "source": [ - "\n", - "#@markdown ##Run this cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "#mounts user's Google Drive to Google Colab.\n", - "\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **2. Install CARE and dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5d6BsNWn_bHL" - }, - "source": [ - "## **2.1. Install key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "3u2mXn3XsWzd", - "cellView": "form" - }, - "source": [ - "#@markdown ##Install CARE and dependencies\n", - "\n", - "\n", - "#Here, we install libraries which are not already included in Colab.\n", - "\n", - "!pip install tifffile # contains tools to operate tiff-files\n", - "!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n", - "!pip install wget\n", - "!pip install memory_profiler\n", - "!pip install fpdf\n", - "\n", - "#Force session restart\n", - "exit(0)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3m8GnyWX-r0Z" - }, - "source": [ - "## **2.2. Restart your runtime**\n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bK6zwRkh-usk" - }, - "source": [ - "** Your Runtime has automatically restarted. This is normal.**\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eDrWDRP2_fRm" - }, - "source": [ - "## **2.3. Load key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "aGxvAcGT-rTq" - }, - "source": [ - "#@markdown ##Load key dependencies\n", - "\n", - "Notebook_version = ['1.12']\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory \n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "\n", - "%load_ext memory_profiler\n", - "\n", - "#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n", - "%tensorflow_version 1.x\n", - "\n", - "import tensorflow \n", - "import tensorflow as tf\n", - "\n", - "print(tensorflow.__version__)\n", - "print(\"Tensorflow enabled.\")\n", - "\n", - "# ------- Variable specific to CARE -------\n", - "from csbdeep.utils import download_and_extract_zip_file, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n", - "from csbdeep.data import RawData, create_patches \n", - "from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n", - "from csbdeep.models import Config, CARE\n", - "from csbdeep import data\n", - "from __future__ import print_function, unicode_literals, absolute_import, division\n", - "%matplotlib inline\n", - "%config InlineBackend.figure_format = 'retina'\n", - "\n", - "\n", - "\n", - "# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - 
"import urllib\n", - "import os, random\n", - "import shutil \n", - "import zipfile\n", - "from tifffile import imread, imsave\n", - "import time\n", - "import sys\n", - "import wget\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import csv\n", - "from glob import glob\n", - "from scipy import signal\n", - "from scipy import ndimage\n", - "from skimage import io\n", - "from sklearn.linear_model import LinearRegression\n", - "from skimage.util import img_as_uint\n", - "import matplotlib as mpl\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from astropy.visualization import simple_norm\n", - "from skimage import img_as_float32\n", - "from skimage.util import img_as_ubyte\n", - "from tqdm import tqdm \n", - "from fpdf import FPDF, HTMLMixin\n", - "from datetime import datetime\n", - "import subprocess\n", - "from pip._internal.operations.freeze import freeze\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "\n", - "#Disable some of the tensorflow warnings\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "print(\"Libraries installed\")\n", - "\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "print('Notebook version: '+Notebook_version[0])\n", - "strlist = Notebook_version[0].split('.')\n", - "Notebook_version_main = strlist[0]+'.'+strlist[1]\n", - "if Notebook_version_main == Latest_notebook_version.columns:\n", - " print(\"This notebook is up-to-date.\")\n", - "else:\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "\n", - "!pip freeze > requirements.txt\n", - "\n", - "#Create a pdf document with training summary\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " # save FPDF() class into a \n", - " # variable pdf \n", - " #from datetime import datetime\n", - "\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'CARE 2D'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and methods:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras','csbdeep']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n", - " if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n", - " aug_text = aug_text+'\\n- rotation'\n", - " if flip_left_right != 0 or flip_top_bottom != 0:\n", - " aug_text = aug_text+'\\n- flipping'\n", - " if random_zoom_magnification != 0:\n", - " aug_text = aug_text+'\\n- random zoom magnification'\n", - " if random_distortion != 0:\n", - " aug_text = aug_text+'\\n- random distortion'\n", - " if image_shear != 0:\n", - " aug_text = aug_text+'\\n- image shearing'\n", - " if skew_image != 0:\n", - " aug_text = aug_text+'\\n- image skewing'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>number_of_patches</td><td>{2}</td></tr>
<tr><td>batch_size</td><td>{3}</td></tr>
<tr><td>number_of_steps</td><td>{4}</td></tr>
<tr><td>percentage_validation</td><td>{5}</td></tr>
<tr><td>initial_learning_rate</td><td>{6}</td></tr>
\n", - " \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_CARE2D.png').shape\n", - " pdf.image('/content/TrainingDataExample_CARE2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " if augmentation:\n", - " ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " pdf.multi_cell(190, 5, txt = ref_3, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", - "\n", - "\n", - "#Make a pdf summary of the QC results\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'CARE 2D'\n", - " #model_name = os.path.basename(full_QC_model_path)\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n", - " if os.path.exists(full_QC_model_path+'Quality Control/lossCurvePlots.png'):\n", - " pdf.image(full_QC_model_path+'Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.', align='L')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n", - " pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+'Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " mSSIM_PvsGT = header[1]\n", - " mSSIM_SvsGT = header[2]\n", - " NRMSE_PvsGT = header[3]\n", - " NRMSE_SvsGT = header[4]\n", - " PSNR_PvsGT = header[5]\n", - " PSNR_SvsGT = header[6]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n", - " html = html+header\n", - " for row in metrics:\n", - " image = 
row[0]\n", - " mSSIM_PvsGT = row[1]\n", - " mSSIM_SvsGT = row[2]\n", - " NRMSE_PvsGT = row[3]\n", - " NRMSE_SvsGT = row[4]\n", - " PSNR_PvsGT = row[5]\n", - " PSNR_SvsGT = row[6]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
\"\"\"\n", - "\n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Fw0kkTU6CsU4" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WzYAA-MuaYrT" - }, - "source": [ - "## **3.1. Setting main training parameters**\n", - "---\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CB6acvUFtWqd" - }, - "source": [ - " **Paths for training, predictions and results**\n", - "\n", - "**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "**Training Parameters**\n", - "\n", - "**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5). **Default value: 50**\n", - "\n", - "**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 128**\n", - "\n", - "**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n", - "\n", - "**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. 
**Default value: 50** \n", - "\n", - "**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n", - "\n", - "**Advanced Parameters - experienced users only**\n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n", - "\n", - "**`number_of_steps`:** Define the number of training steps by epoch. By default or if set to zero this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n", - "\n", - "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n", - "\n", - "**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ewpNJ_I0Mv47", - "cellView": "form" - }, - "source": [ - "#@markdown ###Path to training images:\n", - "\n", - "Training_source = \"\" #@param {type:\"string\"}\n", - "InputFile = Training_source+\"/*.tif\"\n", - "\n", - "Training_target = \"\" #@param {type:\"string\"}\n", - "OutputFile = Training_target+\"/*.tif\"\n", - "\n", - "#Define where the patch file will be saved\n", - "base = \"/content\"\n", - "\n", - "\n", - "# model name and path\n", - "#@markdown ###Name of the model and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# other parameters for training.\n", - "#@markdown ###Training Parameters\n", - "#@markdown Number of epochs:\n", - "number_of_epochs = 50#@param {type:\"number\"}\n", - "\n", - "#@markdown Patch size (pixels) and number\n", - "patch_size = 128#@param {type:\"number\"} # in pixels\n", - "number_of_patches = 50#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "\n", - "Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "\n", - "batch_size = 16#@param {type:\"number\"}\n", - "number_of_steps = 0#@param {type:\"number\"}\n", - "percentage_validation = 10 #@param {type:\"number\"}\n", - "initial_learning_rate = 0.0004 #@param {type:\"number\"}\n", - "\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 16\n", - " percentage_validation = 10\n", - " initial_learning_rate = 0.0004\n", - "\n", - "#Here we define the percentage to use for validation\n", - "percentage = percentage_validation/100\n", - "\n", - "\n", - "#here we check that no model with the same name already exist, if so print a warning\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n", - " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n", - " \n", - "\n", - "# Here we disable pre-trained model by default (in case the cell is not ran)\n", - "Use_pretrained_model = False\n", - "\n", - "# Here we disable data augmentation by default (in case the cell is not ran)\n", - "\n", - "Use_Data_augmentation = False\n", - "\n", - "print(\"Parameters initiated.\")\n", - "\n", - "# This will display a randomly chosen dataset input and output\n", - "random_choice = random.choice(os.listdir(Training_source))\n", - "x = imread(Training_source+\"/\"+random_choice)\n", - "\n", - "\n", - "# Here we check that the input images contains the expected dimensions\n", - "if len(x.shape) == 2:\n", - " print(\"Image dimensions (y,x)\",x.shape)\n", - "\n", - "if not len(x.shape) == 2:\n", - " print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n", - "\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "#Hyperparameters failsafes\n", - "\n", - "# Here we check that patch_size is smaller than the smallest xy dimension of the image \n", - "\n", - "if patch_size > min(Image_Y, Image_X):\n", - " patch_size = min(Image_Y, Image_X)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "# Here we check that patch_size is divisible by 8\n", - "if not patch_size % 8 == 0:\n", - " patch_size = ((int(patch_size / 8)-1) * 8)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "\n", - "os.chdir(Training_target)\n", - "y = imread(Training_target+\"/\"+random_choice)\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n", - "plt.title('Training source')\n", - "plt.axis('off');\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n", - "plt.title('Training target')\n", - "plt.axis('off');\n", - "plt.savefig('/content/TrainingDataExample_CARE2D.png',bbox_inches='tight',pad_inches=0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xGcl7WGP4WHt" - }, - "source": [ - "## **3.2. Data augmentation**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5Lio8hpZ4PJ1" - }, - "source": [ - "Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n", - "\n", - " **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n", - "\n", - "Data augmentation is performed here by [Augmentor.](/~https://github.com/mdbloice/Augmentor)\n", - "\n", - "[Augmentor](/~https://github.com/mdbloice/Augmentor) was described in the following article:\n", - "\n", - "Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n", - "\n", - "**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "htqjkJWt5J_8", - "cellView": "form" - }, - "source": [ - "#Data augmentation\n", - "\n", - "Use_Data_augmentation = False #@param {type:\"boolean\"}\n", - "\n", - "if Use_Data_augmentation:\n", - " !pip install Augmentor\n", - " import Augmentor\n", - "\n", - "\n", - "#@markdown ####Choose a factor by which you want to multiply your original dataset\n", - "\n", - "Multiply_dataset_by = 30 #@param {type:\"slider\", min:1, max:30, step:1}\n", - "\n", - "Save_augmented_images = False #@param {type:\"boolean\"}\n", - "\n", - "Saving_path = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n", - "\n", - "#@markdown ####Mirror and rotate images\n", - "rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "#@markdown ####Random image Zoom\n", - "\n", - "random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "#@markdown ####Random image distortion\n", - "\n", - "random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "\n", - "#@markdown ####Image shearing and skewing \n", - "\n", - "image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n", - "\n", - "skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "\n", - "if Use_Default_Augmentation_Parameters:\n", - " rotate_90_degrees = 0.5\n", - " rotate_270_degrees = 0.5\n", - " flip_left_right = 0.5\n", - " flip_top_bottom = 0.5\n", - "\n", - " if not Multiply_dataset_by >5:\n", - " random_zoom = 0\n", - " random_zoom_magnification = 0.9\n", - " random_distortion = 0\n", - " image_shear = 0\n", - " max_image_shear = 10\n", - " skew_image = 0\n", - " skew_image_magnitude = 0\n", - "\n", - " if Multiply_dataset_by >5:\n", - " random_zoom = 0.1\n", - " random_zoom_magnification = 0.9\n", - " random_distortion = 0.5\n", - " image_shear = 0.2\n", - " max_image_shear = 5\n", - "\n", - "\n", - " if Multiply_dataset_by >25:\n", - " random_zoom = 0.5\n", - " random_zoom_magnification = 0.8\n", - " random_distortion = 0.5\n", - " image_shear = 0.5\n", - " 
max_image_shear = 20\n", - "\n", - "\n", - "\n", - "list_files = os.listdir(Training_source)\n", - "Nb_files = len(list_files)\n", - "\n", - "Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n", - "\n", - "\n", - "if Use_Data_augmentation:\n", - " print(\"Data augmentation enabled\")\n", - "# Here we set the path for the various folder were the augmented images will be loaded\n", - "\n", - "# All images are first saved into the augmented folder\n", - " #Augmented_folder = \"/content/Augmented_Folder\"\n", - " \n", - " if not Save_augmented_images:\n", - " Saving_path= \"/content\"\n", - "\n", - " Augmented_folder = Saving_path+\"/Augmented_Folder\"\n", - " if os.path.exists(Augmented_folder):\n", - " shutil.rmtree(Augmented_folder)\n", - " os.makedirs(Augmented_folder)\n", - "\n", - " #Training_source_augmented = \"/content/Training_source_augmented\"\n", - " Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n", - "\n", - " if os.path.exists(Training_source_augmented):\n", - " shutil.rmtree(Training_source_augmented)\n", - " os.makedirs(Training_source_augmented)\n", - "\n", - " #Training_target_augmented = \"/content/Training_target_augmented\"\n", - " Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n", - "\n", - " if os.path.exists(Training_target_augmented):\n", - " shutil.rmtree(Training_target_augmented)\n", - " os.makedirs(Training_target_augmented)\n", - "\n", - "\n", - "# Here we generate the augmented images\n", - "#Load the images\n", - " p = Augmentor.Pipeline(Training_source, Augmented_folder)\n", - "\n", - "#Define the matching images\n", - " p.ground_truth(Training_target)\n", - "#Define the augmentation possibilities\n", - " if not rotate_90_degrees == 0:\n", - " p.rotate90(probability=rotate_90_degrees)\n", - " \n", - " if not rotate_270_degrees == 0:\n", - " p.rotate270(probability=rotate_270_degrees)\n", - "\n", - " if not flip_left_right == 0:\n", - " p.flip_left_right(probability=flip_left_right)\n", - "\n", - " if not flip_top_bottom == 0:\n", - " p.flip_top_bottom(probability=flip_top_bottom)\n", - "\n", - " if not random_zoom == 0:\n", - " p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n", - " \n", - " if not random_distortion == 0:\n", - " p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n", - "\n", - " if not image_shear == 0:\n", - " p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n", - " \n", - "\n", - " p.sample(int(Nb_augmented_files))\n", - "\n", - " print(int(Nb_augmented_files),\"matching images generated\")\n", - "\n", - "# Here we sort through the images and move them back to augmented trainning source and targets folders\n", - "\n", - " augmented_files = os.listdir(Augmented_folder)\n", - "\n", - " for f in augmented_files:\n", - "\n", - " if (f.startswith(\"_groundtruth_(1)_\")):\n", - " shortname_noprefix = f[17:]\n", - " shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n", - " if not (f.startswith(\"_groundtruth_(1)_\")):\n", - " shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n", - " \n", - "\n", - " for filename in os.listdir(Training_source_augmented):\n", - " os.chdir(Training_source_augmented)\n", - " os.rename(filename, filename.replace('_original', ''))\n", - " \n", - " #Here we clean up the extra files\n", - " shutil.rmtree(Augmented_folder)\n", - "\n", - "if not Use_Data_augmentation:\n", - " print(bcolors.WARNING+\"Data 
augmentation disabled\") \n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bQDuybvyadKU" - }, - "source": [ - "\n", - "## **3.3. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 2D model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8vPkzEBNamE4", - "cellView": "form" - }, - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "\n", - "pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n", - "\n", - "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n", - "\n", - "\n", - "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - "# --------------------- Load the model from the choosen path ------------------------\n", - " if pretrained_model_choice == \"Model_from_file\":\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "# --------------------- Download the a model provided in the XXX ------------------------\n", - "\n", - " if pretrained_model_choice == \"Model_name\":\n", - " pretrained_model_name = \"Model_name\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path) \n", - " wget.download(\"\", pretrained_model_path)\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "# --------------------- Add additional pre-trained models here ------------------------\n", - "\n", - "\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n", - " if not os.path.exists(h5_file_path):\n", - " print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.h5 pretrained model does not exist')\n", - " Use_pretrained_model = False\n", - "\n", - " \n", - "# If the model path contains a pretrain model, we load the training rate, \n", - " if os.path.exists(h5_file_path):\n", - 
"#Here we check if the learning rate can be loaded from the quality control folder\n", - " if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"pretrained network learning rate found\")\n", - " #find the last learning rate\n", - " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n", - " #Find the learning rate corresponding to the lowest validation loss\n", - " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n", - " #print(min_val_loss)\n", - " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n", - " if Weights_choice == \"last\":\n", - " print('Last learning rate: '+str(lastLearningRate))\n", - " if Weights_choice == \"best\":\n", - " print('Learning rate of best validation loss: '+str(bestLearningRate))\n", - " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead')\n", - "\n", - "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n", - " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - "\n", - "\n", - "# Display info about the pretrained model to be loaded (or not)\n", - "if Use_pretrained_model:\n", - " print('Weights found in:')\n", - " print(h5_file_path)\n", - " print('will be loaded prior to training.')\n", - "\n", - "else:\n", - " print(bcolors.WARNING+'No pretrained network will be used.')\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rQndJj70FzfL" - }, - "source": [ - "# **4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tGW2iaU6X5zi" - }, - "source": [ - "## **4.1. Prepare the training data and model for training**\n", - "---\n", - "Here, we use the information from 3. to build the model and convert the training data into a suitable format for training." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WMJnGJpCMa4y", - "cellView": "form" - }, - "source": [ - "#@markdown ##Create the model and dataset objects\n", - "\n", - "# --------------------- Here we delete the model folder if it already exist ------------------------\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\"+W)\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "\n", - "\n", - "\n", - "# --------------------- Here we load the augmented data or the raw data ------------------------\n", - "\n", - "if Use_Data_augmentation:\n", - " Training_source_dir = Training_source_augmented\n", - " Training_target_dir = Training_target_augmented\n", - "\n", - "if not Use_Data_augmentation:\n", - " Training_source_dir = Training_source\n", - " Training_target_dir = Training_target\n", - "# --------------------- ------------------------------------------------\n", - "\n", - "# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n", - "# This file is saved in .npz format and later called when loading the trainig data.\n", - "\n", - "\n", - "raw_data = data.RawData.from_folder(\n", - " basepath=base,\n", - " source_dirs=[Training_source_dir], \n", - " target_dir=Training_target_dir, \n", - " axes='CYX', \n", - " pattern='*.tif*')\n", - "\n", - "X, Y, XY_axes = data.create_patches(\n", - " raw_data, \n", - " patch_filter=None, \n", - " patch_size=(patch_size,patch_size), \n", - " n_patches_per_image=number_of_patches)\n", - "\n", - "print ('Creating 2D training dataset')\n", - "training_path = model_path+\"/rawdata\"\n", - "rawdata1 = training_path+\".npz\"\n", - "np.savez(training_path,X=X, Y=Y, axes=XY_axes)\n", - "\n", - "# Load Training Data\n", - "(X,Y), (X_val,Y_val), axes = load_training_data(rawdata1, validation_split=percentage, verbose=True)\n", - "c = axes_dict(axes)['C']\n", - "n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n", - "\n", - "%memit \n", - "\n", - "#plot of training patches.\n", - "plt.figure(figsize=(12,5))\n", - "plot_some(X[:5],Y[:5])\n", - "plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n", - "\n", - "#plot of validation patches\n", - "plt.figure(figsize=(12,5))\n", - "plot_some(X_val[:5],Y_val[:5])\n", - "plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n", - "\n", - "\n", - "#Here we automatically define number_of_step in function of training data and batch size\n", - "#if (Use_Default_Advanced_Parameters):\n", - "if (Use_Default_Advanced_Parameters) or (number_of_steps == 0):\n", - " number_of_steps = int(X.shape[0]/batch_size)+1\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "#Here we ensure that the learning rate set correctly when using pre-trained models\n", - "if Use_pretrained_model:\n", - " if Weights_choice == \"last\":\n", - " initial_learning_rate = lastLearningRate\n", - "\n", - " if Weights_choice == \"best\": \n", - " initial_learning_rate = bestLearningRate\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "\n", - "#Here we create the configuration file\n", - "\n", - "config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, unet_kern_size=5, unet_n_depth=3, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n", - "\n", - "print(config)\n", - "vars(config)\n", - "\n", - "# Compile the CARE model for network training\n", - "model_training= CARE(config, model_name, basedir=model_path)\n", - "\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "# Load the pretrained weights \n", - "if Use_pretrained_model:\n", - " 
model_training.load_weights(h5_file_path)\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wQPz0F6JlvJR" - }, - "source": [ - "## **4.2. Start Training**\n", - "---\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n", - "\n", - "**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "j_Qm5JBmlvJg", - "cellView": "form" - }, - "source": [ - "#@markdown ##Start training\n", - "\n", - "start = time.time()\n", - "\n", - "# Start Training\n", - "history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n", - "\n", - "print(\"Training, done.\")\n", - "\n", - "# copy the .npz to the model's folder\n", - "shutil.copyfile(model_path+'/rawdata.npz',model_path+'/'+model_name+'/rawdata.npz')\n", - "\n", - "# convert the history.history dict to a pandas DataFrame: \n", - "lossData = pd.DataFrame(history.history) \n", - "\n", - "if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n", - " shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "# The training evaluation.csv is saved (overwrites the Files if needed). \n", - "lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n", - "with open(lossDataCSVpath, 'w') as f:\n", - " writer = csv.writer(f)\n", - " writer.writerow(['loss','val_loss', 'learning rate'])\n", - " for i in range(len(history.history['loss'])):\n", - " writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n", - "\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "model_training.export_TF()\n", - "\n", - "print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n", - "\n", - "pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QYuIOWQ3imuU" - }, - "source": [ - "# **5. 
Evaluate your model**\n", - "---\n", - "\n", - "This section allows you to perform important quality checks on the validity and generalisability of the trained model. \n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zazOZ3wDx0zQ", - "cellView": "form" - }, - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "QC_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "QC_model_name = os.path.basename(QC_model_folder)\n", - "QC_model_path = os.path.dirname(QC_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", - "\n", - "full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n", - "if os.path.exists(full_QC_model_path):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "loss_displayed = False" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yDY9dtzdUTLh" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n", - "\n", - "**Note: Plots of the losses will be shown in a linear and in a log scale. 
This can help visualise changes in the losses at different magnitudes. However, note that if the losses are negative the plot on the log scale will be empty. This is not an error.**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vMzSP50kMv5p", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n", - "loss_displayed = True\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "\n", - "with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[0]))\n", - " vallossDataFromCSV.append(float(row[1]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (log scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n", - "plt.show()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "biT9FI9Ri77_" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "\n", - "This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n", - "\n", - "**1. The SSIM (structural similarity) map** \n", - "\n", - "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n", - "\n", - "**mSSIM** is the SSIM value calculated across the entire window of both images.\n", - "\n", - "**The output below shows the SSIM maps with the mSSIM**\n", - "\n", - "**2. The RSE (Root Squared Error) map** \n", - "\n", - "This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n", - "\n", - "\n", - "**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. 
Good agreement yields low NRMSE scores.\n", - "\n", - "**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n", - "\n", - "**The output below shows the RSE maps with the NRMSE and PSNR values.**\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "nAs4Wni7VYbq", - "cellView": "form" - }, - "source": [ - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_QC_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "# Create a quality control/Prediction Folder\n", - "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n", - " shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "# Activate the pretrained model. \n", - "model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n", - "\n", - "# List Tif images in Source_QC_folder\n", - "Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n", - "Z = sorted(glob(Source_QC_folder_tif))\n", - "Z = list(map(imread,Z))\n", - "print('Number of test dataset found in the folder: '+str(len(Z)))\n", - "\n", - "\n", - "# Perform prediction on all datasets in the Source_QC folder\n", - "for filename in os.listdir(Source_QC_folder):\n", - " img = imread(os.path.join(Source_QC_folder, filename))\n", - " predicted = model_training.predict(img, axes='YX')\n", - " os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - " imsave(filename, predicted)\n", - "\n", - "\n", - "def ssim(img1, img2):\n", - " return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n", - "\n", - "\n", - "def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " \"\"\"Percentile-based image normalization.\"\"\"\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - "def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " if dtype is not None:\n", - " x = x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - "def norm_minmse(gt, x, normalize_gt=True):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - "\n", - " \"\"\"\n", - " normalizes and affinely scales an image pair such that the MSE is minimized \n", - " \n", - " Parameters\n", - " ----------\n", - " gt: ndarray\n", - " the ground truth image \n", - " x: ndarray\n", - " the image that will be affinely scaled 
\n", - " normalize_gt: bool\n", - " set to True of gt image should be normalized (default)\n", - " Returns\n", - " -------\n", - " gt_scaled, x_scaled \n", - " \"\"\"\n", - " if normalize_gt:\n", - " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", - " x = x.astype(np.float32, copy=False) - np.mean(x)\n", - " #x = x - np.mean(x)\n", - " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", - " #gt = gt - np.mean(gt)\n", - " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", - " return gt, scale * x\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - "with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n", - "\n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - "\n", - " for i in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", - " print('Running QC on: '+i)\n", - " # -------------------------------- Target test data (Ground truth) --------------------------------\n", - " test_GT = io.imread(os.path.join(Target_QC_folder, i))\n", - "\n", - " # -------------------------------- Source test data --------------------------------\n", - " test_source = io.imread(os.path.join(Source_QC_folder,i))\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n", - " test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n", - "\n", - " # -------------------------------- Prediction --------------------------------\n", - " test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n", - " test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n", - "\n", - "\n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n", - "\n", - " #Save ssim_maps\n", - " img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n", - " img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n", - " \n", - " # Calculate the Root Squared Error (RSE) maps\n", - " img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", - " img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n", - "\n", - " # Save SE maps\n", - " img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n", - " img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n", - 
" io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n", - "\n", - "\n", - " # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n", - "\n", - " # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n", - " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", - " NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n", - " \n", - " # We can also measure the peak signal to noise ratio between the images\n", - " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", - " PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n", - "\n", - " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n", - "\n", - "\n", - "# All data is now processed saved\n", - "Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n", - "\n", - "plt.figure(figsize=(20,20))\n", - "# Currently only displays the last computed set, from memory\n", - "# Target (Ground-truth)\n", - "plt.subplot(3,3,1)\n", - "plt.axis('off')\n", - "img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n", - "plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n", - "plt.title('Target',fontsize=15)\n", - "\n", - "# Source\n", - "plt.subplot(3,3,2)\n", - "plt.axis('off')\n", - "img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n", - "plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n", - "plt.title('Source',fontsize=15)\n", - "\n", - "#Prediction\n", - "plt.subplot(3,3,3)\n", - "plt.axis('off')\n", - "img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n", - "plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n", - "plt.title('Prediction',fontsize=15)\n", - "\n", - "#Setting up colours\n", - "cmap = plt.cm.CMRmap\n", - "\n", - "#SSIM between GT and Source\n", - "plt.subplot(3,3,5)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - "plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - "plt.title('Target vs. 
Source',fontsize=15)\n", - "plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", - "plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - "plt.subplot(3,3,6)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - "plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - "plt.title('Target vs. Prediction',fontsize=15)\n", - "plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#Root Squared Error between GT and Source\n", - "plt.subplot(3,3,8)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n", - "plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n", - "plt.title('Target vs. Source',fontsize=15)\n", - "plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n", - "#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n", - "plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#Root Squared Error between GT and Prediction\n", - "plt.subplot(3,3,9)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - "plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", - "plt.title('Target vs. Prediction',fontsize=15)\n", - "plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n", - "plt.savefig(full_QC_model_path+'Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "qc_pdf_export()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "69aJVFfsqXbY" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." 
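In practice, this step reduces to loading the trained model and calling its `predict` function on each unseen image. The sketch below is a minimal illustration of that workflow using the same csbdeep `CARE` API and `tifffile` I/O already used in this notebook; `my_model`, `model_dir`, `Data_folder` and `Result_folder` are placeholder names to be replaced with your own model name and paths.

```python
# Minimal sketch: restore unseen 2D images with a trained CARE model.
# The model name and all folder paths below are placeholders.
import os
from tifffile import imread, imsave
from csbdeep.models import CARE

model = CARE(config=None, name='my_model', basedir='model_dir')  # load the trained model

for filename in os.listdir('Data_folder'):
    img = imread(os.path.join('Data_folder', filename))          # 2D image, axes YX
    restored = model.predict(img, axes='YX')                      # CARE restoration
    imsave(os.path.join('Result_folder', filename), restored)     # save as ImageJ-compatible TIFF
```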
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tcPNRq1TrMPB" - }, - "source": [ - "## **6.1. Generate prediction(s) from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n", - "\n", - "**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output images." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "Am2JSmpC0frj", - "cellView": "form" - }, - "source": [ - "#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n", - "\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Result_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "# model name and path\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n", - "\n", - "\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "\n", - "\n", - "#Activate the pretrained model. \n", - "model_training = CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n", - "\n", - "\n", - "# creates a loop, creating filenames and saving them\n", - "for filename in os.listdir(Data_folder):\n", - " img = imread(os.path.join(Data_folder,filename))\n", - " restored = model_training.predict(img, axes='YX')\n", - " os.chdir(Result_folder)\n", - " imsave(filename,restored)\n", - "\n", - "print(\"Images saved into folder:\", Result_folder)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bShxBHY4vFFd" - }, - "source": [ - "## **6.2. 
Inspect the predicted output**\n", - "---\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "6b2t6SLQvIBO" - }, - "source": [ - "# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n", - "\n", - "# This will display a randomly chosen dataset input and predicted output\n", - "random_choice = random.choice(os.listdir(Data_folder))\n", - "x = imread(Data_folder+\"/\"+random_choice)\n", - "\n", - "os.chdir(Result_folder)\n", - "y = imread(Result_folder+\"/\"+random_choice)\n", - "\n", - "plt.figure(figsize=(16,8))\n", - "\n", - "plt.subplot(1,2,1)\n", - "plt.axis('off')\n", - "plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n", - "plt.title('Input')\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.axis('off')\n", - "plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n", - "plt.title('Predicted output');\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.3. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u4pcBe8Z3T2J" - }, - "source": [ - "#**Thank you for using CARE 2D!**" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"CARE_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1P6awmDRpRza-c5eOGDLBI8rcgBTTdieV","timestamp":1624966407105},{"file_id":"1BhRVVSz7iXSbfXvLoCsPSglyho3CK_9p","timestamp":1619455558162},{"file_id":"1hMjEc-Ex7j-jeYGclaPw2x3OgbkeC6Bl","timestamp":1610626439596},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **CARE: Content-aware image restoration (2D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. 
For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 2D datasets. If you are interested in restoring a 3D dataset, you should use the CARE 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: /~https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network, please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. 
You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Install CARE and dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"5d6BsNWn_bHL"},"source":["## **1.1. 
Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["#@markdown ##Install CARE and dependencies\n","\n","\n","#Here, we install libraries which are not already included in Colab.\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install memory_profiler\n","!pip install fpdf\n","\n","#Force session restart\n","exit(0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bK6zwRkh-usk"},"source":["## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
\n","\n"]},{"cell_type":"markdown","metadata":{"id":"eDrWDRP2_fRm"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"aGxvAcGT-rTq","cellView":"form"},"source":["#@markdown ##Load key dependencies\n","\n","Notebook_version = '1.13'\n","Network = 'CARE (2D)'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","\n","%load_ext memory_profiler\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","%tensorflow_version 1.x\n","\n","import tensorflow \n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import 
datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","!pip freeze > requirements.txt\n","\n","#Create a pdf document with training summary\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," # save FPDF() class into a \n"," # variable pdf \n"," #from datetime import datetime\n","\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'CARE 2D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image 
patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n"," if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if flip_left_right != 0 or flip_top_bottom != 0:\n"," aug_text = aug_text+'\\n- flipping'\n"," if random_zoom_magnification != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if random_distortion != 0:\n"," aug_text = aug_text+'\\n- random distortion'\n"," if image_shear != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>number_of_patches</td><td>{2}</td></tr>
<tr><td>batch_size</td><td>{3}</td></tr>
<tr><td>number_of_steps</td><td>{4}</td></tr>
<tr><td>percentage_validation</td><td>{5}</td></tr>
<tr><td>initial_learning_rate</td><td>{6}</td></tr>
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_CARE2D.png').shape\n"," pdf.image('/content/TrainingDataExample_CARE2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," if augmentation:\n"," ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","#Make a pdf summary of the QC results\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'CARE 2D'\n"," #model_name = os.path.basename(full_QC_model_path)\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.', align='L')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
\"\"\"\n","\n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **2. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["\n","#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"oIwqNMJ5flZX"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
"]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5). **Default value: 50**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 128**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 50** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default or if set to zero this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patches / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.tif\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.tif\"\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 50#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 128#@param {type:\"number\"} # in pixels\n","number_of_patches = 50#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#Here we define the percentage to use for validation\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","print(\"Parameters initiated.\")\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_CARE2D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by [Augmentor.](/~https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](/~https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 5 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 1 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n","\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n","\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n","\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not 
Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. 
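Stripped of the form logic, warm-starting from a previous CARE 2D model amounts to locating its weights file and loading it into a freshly built model. The sketch below shows the idea only; every path, name and configuration value is a placeholder rather than something taken from this notebook.

```python
import os
from csbdeep.models import Config, CARE

# Placeholder paths/names - in the notebook these come from the form fields
pretrained_model_path = "/content/gdrive/MyDrive/models/my_previous_CARE2D"
Weights_choice = "best"          # "best" or "last"
h5_file_path = os.path.join(pretrained_model_path, "weights_" + Weights_choice + ".h5")

if os.path.exists(h5_file_path):
    # Build a new model (configuration values are illustrative) and warm-start it
    config = Config('YXC', n_channel_in=1, n_channel_out=1, probabilistic=True)
    model = CARE(config, 'my_new_model', basedir='/content/gdrive/MyDrive/models')
    model.load_weights(h5_file_path)
    print("Pretrained weights loaded")
else:
    print("weights_" + Weights_choice + ".h5 not found - training from scratch")
```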
**You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," 
#print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\"+W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","\n","raw_data = data.RawData.from_folder(\n"," basepath=base,\n"," source_dirs=[Training_source_dir], \n"," target_dir=Training_target_dir, \n"," axes='CYX', \n"," pattern='*.tif*')\n","\n","X, Y, XY_axes = data.create_patches(\n"," raw_data, \n"," patch_filter=None, \n"," patch_size=(patch_size,patch_size), \n"," n_patches_per_image=number_of_patches)\n","\n","print ('Creating 2D training dataset')\n","training_path = model_path+\"/rawdata\"\n","rawdata1 = training_path+\".npz\"\n","np.savez(training_path,X=X, Y=Y, axes=XY_axes)\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(rawdata1, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","%memit \n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","#if (Use_Default_Advanced_Parameters):\n","if (Use_Default_Advanced_Parameters) or (number_of_steps == 0):\n"," number_of_steps = int(X.shape[0]/batch_size)+1\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we create the configuration file\n","\n","config = Config(axes, n_channel_in, n_channel_out, probabilistic=True, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, unet_kern_size=5, unet_n_depth=3, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = 
Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training, done.\")\n","\n","# copy the .npz to the model's folder\n","shutil.copyfile(model_path+'/rawdata.npz',model_path+'/'+model_name+'/rawdata.npz')\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model_training.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows you to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","loss_displayed = False"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","**Note: Plots of the losses will be shown in a linear and in a log scale. This can help visualise changes in the losses at different magnitudes. However, note that if the losses are negative the plot on the log scale will be empty. This is not an error.**"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","loss_displayed = True\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. \n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX')\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the 
csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," 
writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","\n","#Activate the pretrained model. \n","model_training = CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","# creates a loop, creating filenames and saving them\n","for filename in os.listdir(Data_folder):\n"," img = imread(os.path.join(Data_folder,filename))\n"," restored = model_training.predict(img, axes='YX')\n"," os.chdir(Result_folder)\n"," imsave(filename,restored)\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bShxBHY4vFFd"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"6b2t6SLQvIBO"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","plt.figure(figsize=(16,8))\n","\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.title('Predicted output');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. 
Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and, after that, clean the original folder tree (datasets, results, trained model, etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"YLiyBg8Vvk6E"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version now includes an automatic restart that sets the h5py library to v2.10.\n","* Sections 1 and 2 are now swapped for better export of *requirements.txt*.\n","\n","* This version also now includes a built-in version check and the version log that you're reading now."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using CARE 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb index 7be40fe9..ac5beaaa 100644 --- a/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CARE_3D_ZeroCostDL4Mic.ipynb @@ -1,2082 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "CARE_3D_ZeroCostDL4Mic.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "V9zNGvape2-I" - }, - "source": [ - "# **CARE: Content-aware image restoration (3D)**\n", - "\n", - "---\n", - "\n", - "CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n", - "\n", - " **This particular notebook enables restoration of 3D dataset. If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is largely based on the following paper: \n", - "\n", - "**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n", - "\n", - "And source code found in: /~https://github.com/csbdeep/csbdeep\n", - "\n", - "For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n", - "\n", - "We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use our notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. 
This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vNMDQHm0Ah-Z" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - " For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n", - "\n", - " Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n", - "\n", - "**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n", - "\n", - " **Additionally, the corresponding input and output files need to have the same name**.\n", - "\n", - " Please note that you currently can **only use .tif files!**\n", - "\n", - " You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. \n", - "\n", - "\n", - "Here's a common data structure that can work:\n", - "* Experiment A\n", - " - **Training dataset**\n", - " - Low SNR images (Training_source)\n", - " - img_1.tif, img_2.tif, ...\n", - " - High SNR images (Training_target)\n", - " - img_1.tif, img_2.tif, ...\n", - " - **Quality control dataset**\n", - " - Low SNR images\n", - " - img_1.tif, img_2.tif\n", - " - High SNR images\n", - " - img_1.tif, img_2.tif\n", - " - **Data to be predicted**\n", - " - **Results**\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4-r1gE7Iamv" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. 
Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "BDhmUgqCStlm", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "%tensorflow_version 1.x\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-oqBTeLaImnU" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "01Djr8v-5pPk" - }, - "source": [ - "#@markdown ##Play the cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **2. Install CARE and dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ohimRYp7UOv-" - }, - "source": [ - "## **2.1. Install key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "3u2mXn3XsWzd", - "cellView": "form" - }, - "source": [ - "#@markdown ##Install CARE and dependencies\n", - "\n", - "\n", - "#Here, we install libraries which are not already included in Colab.\n", - "!pip install tifffile # contains tools to operate tiff-files\n", - "!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n", - "!pip install wget\n", - "!pip install fpdf\n", - "!pip install memory_profiler\n", - "\n", - "#Force session restart\n", - "exit(0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0ntZU7cKUSpp" - }, - "source": [ - "## **2.2. Restart your runtime**\n", - "---\n", - "\n", - "\n", - "\n", - "** Your Runtime has automatically restarted. This is normal.**\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GV9Saw1JUVVP" - }, - "source": [ - "## **2.3. Load key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "TYEDoeMMUY96", - "cellView": "form" - }, - "source": [ - "#@markdown ##Load key dependencies\n", - "\n", - "Notebook_version = ['1.12']\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory \n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n", - "%tensorflow_version 1.x\n", - "import tensorflow\n", - "import tensorflow as tf\n", - "\n", - "print(tensorflow.__version__)\n", - "print(\"Tensorflow enabled.\")\n", - "\n", - "# ------- Variable specific to CARE -------\n", - "from csbdeep.utils import download_and_extract_zip_file, normalize, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n", - "from csbdeep.data import RawData, create_patches \n", - "from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n", - "from csbdeep.models import Config, CARE\n", - "from csbdeep import data\n", - "from __future__ import print_function, unicode_literals, absolute_import, division\n", - "%matplotlib inline\n", - "%config InlineBackend.figure_format = 'retina'\n", - "\n", - "# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - "import urllib\n", - "import os, random\n", - "import shutil \n", - "import zipfile\n", - "from tifffile import imread, imsave\n", - "import 
time\n", - "import sys\n", - "import wget\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import csv\n", - "from glob import glob\n", - "from scipy import signal\n", - "from scipy import ndimage\n", - "from skimage import io\n", - "from sklearn.linear_model import LinearRegression\n", - "from skimage.util import img_as_uint\n", - "import matplotlib as mpl\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from astropy.visualization import simple_norm\n", - "from skimage import img_as_float32\n", - "from skimage.util import img_as_ubyte\n", - "from tqdm import tqdm \n", - "from fpdf import FPDF, HTMLMixin\n", - "from datetime import datetime\n", - "from pip._internal.operations.freeze import freeze\n", - "import subprocess\n", - "\n", - "\n", - "# For sliders and dropdown menu and progress bar\n", - "from ipywidgets import interact\n", - "import ipywidgets as widgets\n", - "\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "\n", - "#Disable some of the tensorflow warnings\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "print(\"Libraries installed\")\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "if Notebook_version == list(Latest_notebook_version.columns):\n", - " print(\"This notebook is up-to-date.\")\n", - "\n", - "if not Notebook_version == list(Latest_notebook_version.columns):\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'CARE 3D'\n", - " #model_name = 'little_CARE_test'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and method:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras','csbdeep']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The GPU used was a '+gpu_name+'.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented by'\n", - " if Rotation:\n", - " aug_text = aug_text+'\\n- rotation'\n", - " if Flip:\n", - " aug_text = aug_text+'\\n- flipping'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ParameterValue
number_of_epochs{0}
patch_size{1}
patch_height{2}
number_of_patches{3}
batch_size{4}
number_of_steps{5}
percentage_validation{6}
initial_learning_rate{7}
\n", - " \"\"\".format(number_of_epochs,patch_size,patch_height,number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(32, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(30, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_CARE3D.png').shape\n", - " pdf.image('/content/TrainingDataExample_CARE3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " # if Use_Data_augmentation:\n", - " # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'CARE 3D'\n", - " #model_name = os.path.basename(QC_model_folder)\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n", - " if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n", - " pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " # pdf.ln(3)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n", - " pdf.ln(3)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n", - " pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " slice_n = header[1]\n", - " mSSIM_PvsGT = header[2]\n", - " mSSIM_SvsGT = header[3]\n", - " NRMSE_PvsGT = header[4]\n", - " NRMSE_SvsGT = header[5]\n", - " PSNR_PvsGT = header[6]\n", - " PSNR_SvsGT = header[7]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n", - " html = html+header\n", - " for row in metrics:\n", - " image = row[0]\n", - " slice_n = 
row[1]\n", - " mSSIM_PvsGT = row[2]\n", - " mSSIM_SvsGT = row[3]\n", - " NRMSE_PvsGT = row[4]\n", - " NRMSE_SvsGT = row[5]\n", - " PSNR_PvsGT = row[6]\n", - " PSNR_SvsGT = row[7]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}{7}
{0}{1}{2}{3}{4}{5}{6}{7}
\"\"\"\n", - " \n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Fw0kkTU6CsU4" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WzYAA-MuaYrT" - }, - "source": [ - "## **3.1. Setting main training parameters**\n", - "---\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CB6acvUFtWqd" - }, - "source": [ - " **Paths for training, predictions and results**\n", - "\n", - "**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "**Training Parameters**\n", - "\n", - "**`number of epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 40**\n", - "\n", - "**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n", - "\n", - "**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. 
When analysing isotropic stacks patch_size and patch_height should have similar values.\n", - "\n", - "**When choosing the patch_size and patch_height, the values should be i) large enough that they will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n", - "\n", - "**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n", - "\n", - "**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 200** \n", - "\n", - "**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n", - "\n", - "**Advanced Parameters - experienced users only**\n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n", - "\n", - "**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n", - "\n", - "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n", - "\n", - "**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "ewpNJ_I0Mv47" - }, - "source": [ - "\n", - "#@markdown ###Path to training images:\n", - "\n", - "# base folder of GT and low images\n", - "base = \"/content\"\n", - "\n", - "# low SNR images\n", - "Training_source = \"\" #@param {type:\"string\"}\n", - "lowfile = Training_source+\"/*.tif\"\n", - "# Ground truth images\n", - "Training_target = \"\" #@param {type:\"string\"}\n", - "GTfile = Training_target+\"/*.tif\"\n", - "\n", - "\n", - "# model name and path\n", - "#@markdown ###Name of the model and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "# create the training data file into model_path folder.\n", - "training_data = model_path+\"/my_training_data.npz\"\n", - "\n", - "# other parameters for training.\n", - "#@markdown ###Training Parameters\n", - "#@markdown Number of epochs:\n", - "\n", - "number_of_epochs = 40#@param {type:\"number\"}\n", - "\n", - "#@markdown Patch size (pixels) and number\n", - "patch_size = 80#@param {type:\"number\"} # pixels in\n", - "patch_height = 8#@param {type:\"number\"}\n", - "number_of_patches = 200#@param {type:\"number\"}\n", - "\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "\n", - "Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "\n", - "batch_size = 16#@param {type:\"number\"}\n", - "number_of_steps = 300#@param {type:\"number\"}\n", - "percentage_validation = 10 #@param {type:\"number\"}\n", - "initial_learning_rate = 0.0004 #@param {type:\"number\"}\n", - "\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 16\n", - " percentage_validation = 10\n", - " initial_learning_rate = 0.0004\n", - "\n", - "percentage = 
percentage_validation/100\n", - "\n", - "\n", - "#here we check that no model with the same name already exist, if so print a warning\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n", - " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n", - " \n", - " \n", - "# Here we disable pre-trained model by default (in case the next cell is not ran)\n", - "Use_pretrained_model = False\n", - "\n", - "# Here we disable data augmentation by default (in case the cell is not ran)\n", - "\n", - "Use_Data_augmentation = False\n", - "\n", - "\n", - "#Load one randomly chosen training source file\n", - "\n", - "random_choice=random.choice(os.listdir(Training_source))\n", - "x = imread(Training_source+\"/\"+random_choice)\n", - "\n", - "\n", - "# Here we check that the input images are stacks\n", - "if len(x.shape) == 3:\n", - " print(\"Image dimensions (z,y,x)\",x.shape)\n", - "\n", - "if not len(x.shape) == 3:\n", - " print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. Image dimension\",x.shape)\n", - "\n", - "\n", - "#Find image Z dimension and select the mid-plane\n", - "Image_Z = x.shape[0]\n", - "mid_plane = int(Image_Z / 2)+1\n", - "\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[1]\n", - "Image_X = x.shape[2]\n", - "\n", - "#Hyperparameters failsafes\n", - "\n", - "# Here we check that patch_size is smaller than the smallest xy dimension of the image \n", - "\n", - "if patch_size > min(Image_Y, Image_X):\n", - " patch_size = min(Image_Y, Image_X)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "# Here we check that patch_size is divisible by 8\n", - "if not patch_size % 8 == 0:\n", - " patch_size = ((int(patch_size / 8)-1) * 8)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "# Here we check that patch_height is smaller than the z dimension of the image \n", - "\n", - "if patch_height > Image_Z :\n", - " patch_height = Image_Z\n", - " print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n", - "\n", - "# Here we check that patch_height is divisible by 4\n", - "if not patch_height % 4 == 0:\n", - " patch_height = ((int(patch_height / 4)-1) * 4)\n", - " if patch_height == 0:\n", - " patch_height = 4\n", - " print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n", - "\n", - "\n", - "#Load one randomly chosen training target file\n", - "\n", - "os.chdir(Training_target)\n", - "y = imread(Training_target+\"/\"+random_choice)\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n", - "plt.axis('off')\n", - "plt.title('Low SNR image (single Z plane)');\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n", - "plt.axis('off')\n", - "plt.title('High SNR image (single Z plane)');\n", - 
"plt.savefig('/content/TrainingDataExample_CARE3D.png',bbox_inches='tight',pad_inches=0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xGcl7WGP4WHt" - }, - "source": [ - "## **3.2. Data augmentation**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5Lio8hpZ4PJ1" - }, - "source": [ - "Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n", - "\n", - " **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n", - "\n", - "Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n", - "\n", - "**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "htqjkJWt5J_8" - }, - "source": [ - "Use_Data_augmentation = False #@param{type:\"boolean\"}\n", - "\n", - "#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n", - "\n", - "#@markdown **Rotate each image 3 times by 90 degrees.**\n", - "Rotation = True #@param{type:\"boolean\"}\n", - "\n", - "#@markdown **Flip each image once around the x axis of the stack.**\n", - "Flip = True #@param{type:\"boolean\"}\n", - "\n", - "\n", - "#@markdown **Would you like to save your augmented images?**\n", - "\n", - "Save_augmented_images = False #@param {type:\"boolean\"}\n", - "\n", - "Saving_path = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "if not Save_augmented_images:\n", - " Saving_path= \"/content\"\n", - "\n", - "\n", - "def rotation_aug(Source_path, Target_path, flip=False):\n", - " Source_images = os.listdir(Source_path)\n", - " Target_images = os.listdir(Target_path)\n", - " \n", - " for image in Source_images:\n", - " source_img = io.imread(os.path.join(Source_path,image))\n", - " target_img = io.imread(os.path.join(Target_path,image))\n", - " \n", - " # Source Rotation\n", - " source_img_90 = np.rot90(source_img,axes=(1,2))\n", - " source_img_180 = np.rot90(source_img_90,axes=(1,2))\n", - " source_img_270 = np.rot90(source_img_180,axes=(1,2))\n", - "\n", - " # Target Rotation\n", - " target_img_90 = np.rot90(target_img,axes=(1,2))\n", - " target_img_180 = np.rot90(target_img_90,axes=(1,2))\n", - " target_img_270 = np.rot90(target_img_180,axes=(1,2))\n", - "\n", - " # Add a flip to the rotation\n", - " \n", - " if flip == True:\n", - " source_img_lr = np.fliplr(source_img)\n", - " source_img_90_lr = np.fliplr(source_img_90)\n", - " source_img_180_lr = np.fliplr(source_img_180)\n", - " source_img_270_lr = np.fliplr(source_img_270)\n", - "\n", - " target_img_lr = np.fliplr(target_img)\n", - " target_img_90_lr = np.fliplr(target_img_90)\n", - " target_img_180_lr = np.fliplr(target_img_180)\n", - " target_img_270_lr = np.fliplr(target_img_270)\n", - "\n", - " #source_img_90_ud = np.flipud(source_img_90)\n", - " \n", - " # Save the augmented files\n", - " # Source 
images\n", - " io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n", - " # Target images\n", - " io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n", - "\n", - " if flip == True:\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n", - "\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n", - "\n", - "def flip(Source_path, Target_path):\n", - " Source_images = os.listdir(Source_path)\n", - " Target_images = os.listdir(Target_path) \n", - "\n", - " for image in Source_images:\n", - " source_img = io.imread(os.path.join(Source_path,image))\n", - " target_img = io.imread(os.path.join(Target_path,image))\n", - " \n", - " source_img_lr = np.fliplr(source_img)\n", - " target_img_lr = np.fliplr(target_img)\n", - "\n", - " io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n", - "\n", - " io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n", - "\n", - "\n", - "if Use_Data_augmentation:\n", - "\n", - " if os.path.exists(Saving_path+'/augmented_source'):\n", - " shutil.rmtree(Saving_path+'/augmented_source')\n", - " os.mkdir(Saving_path+'/augmented_source')\n", - "\n", - " if os.path.exists(Saving_path+'/augmented_target'):\n", - " shutil.rmtree(Saving_path+'/augmented_target') \n", - " os.mkdir(Saving_path+'/augmented_target')\n", - "\n", - " print(\"Data augmentation enabled\")\n", - " print(\"Data augmentation in progress....\")\n", - "\n", - " if Rotation == True:\n", - " rotation_aug(Training_source,Training_target,flip=Flip)\n", - " \n", - " elif Rotation == False and Flip == True:\n", - " flip(Training_source,Training_target)\n", - " print(\"Done\")\n", - "\n", - "\n", - "if not Use_Data_augmentation:\n", - " print(bcolors.WARNING+\"Data augmentation disabled\")\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bQDuybvyadKU" - }, - "source": [ - "\n", - "## **3.3. 
Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 3D model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pret-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "8vPkzEBNamE4", - "cellView": "form" - }, - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "\n", - "pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n", - "\n", - "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n", - "\n", - "\n", - "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - "# --------------------- Load the model from the choosen path ------------------------\n", - " if pretrained_model_choice == \"Model_from_file\":\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "\n", - "# --------------------- Download the a model provided in the XXX ------------------------\n", - "\n", - " if pretrained_model_choice == \"Model_name\":\n", - " pretrained_model_name = \"Model_name\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path) \n", - " wget.download(\"\", pretrained_model_path)\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "# --------------------- Add additional pre-trained models here ------------------------\n", - "\n", - "\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n", - " if not os.path.exists(h5_file_path):\n", - " print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n", - " Use_pretrained_model = False\n", - "\n", - " \n", - "# If the model path contains a pretrain model, we load the training rate, \n", - " if os.path.exists(h5_file_path):\n", - "#Here we check if the learning rate can be loaded from the quality control folder\n", - " if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - "\n", - " with 
open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " \n", - " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"pretrained network learning rate found\")\n", - " #find the last learning rate\n", - " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n", - " #Find the learning rate corresponding to the lowest validation loss\n", - " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n", - " #print(min_val_loss)\n", - " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n", - "\n", - " if Weights_choice == \"last\":\n", - " print('Last learning rate: '+str(lastLearningRate))\n", - "\n", - " if Weights_choice == \"best\":\n", - " print('Learning rate of best validation loss: '+str(bestLearningRate))\n", - "\n", - " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead')\n", - "\n", - "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n", - " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - "\n", - "\n", - "# Display info about the pretrained model to be loaded (or not)\n", - "if Use_pretrained_model:\n", - " print('Weights found in:')\n", - " print(h5_file_path)\n", - " print('will be loaded prior to training.')\n", - "\n", - "else:\n", - " print(bcolors.WARNING+'No pretrained nerwork will be used.')\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rQndJj70FzfL" - }, - "source": [ - "# **4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tGW2iaU6X5zi" - }, - "source": [ - "## **4.1. Prepare the training data and model for training**\n", - "---\n", - "Here, we use the information from 3. to build the model and convert the training data into a suitable format for training." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "WMJnGJpCMa4y", - "cellView": "form" - }, - "source": [ - "#@markdown ##Create the model and dataset objects\n", - "\n", - "# --------------------- Here we delete the model folder if it already exist ------------------------\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "\n", - "\n", - "# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n", - "# This file is saved in .npz format and later called when loading the trainig data.\n", - "\n", - "if Use_Data_augmentation == True:\n", - " Training_source = Saving_path+'/augmented_source'\n", - " Training_target = Saving_path+'/augmented_target'\n", - "\n", - "raw_data = RawData.from_folder (\n", - " basepath = base,\n", - " source_dirs = [Training_source],\n", - " target_dir = Training_target,\n", - " axes = 'ZYX',\n", - " pattern='*.tif*'\n", - ")\n", - "X, Y, XY_axes = create_patches (\n", - " raw_data = raw_data,\n", - " patch_size = (patch_height,patch_size,patch_size),\n", - " n_patches_per_image = number_of_patches, \n", - " save_file = training_data,\n", - ")\n", - "\n", - "assert X.shape == Y.shape\n", - "print(\"shape of X,Y =\", X.shape)\n", - "print(\"axes of X,Y =\", XY_axes)\n", - "\n", - "print ('Creating 3D training dataset')\n", - "\n", - "# Load Training Data\n", - "(X,Y), (X_val,Y_val), axes = load_training_data(training_data, validation_split=percentage, verbose=True)\n", - "c = axes_dict(axes)['C']\n", - "n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n", - "\n", - "#Plot example patches\n", - "\n", - "#plot of training patches.\n", - "plt.figure(figsize=(12,5))\n", - "plot_some(X[:5],Y[:5])\n", - "plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n", - "\n", - "#plot of validation patches\n", - "plt.figure(figsize=(12,5))\n", - "plot_some(X_val[:5],Y_val[:5])\n", - "plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n", - "\n", - "\n", - "#Here we automatically define number_of_step in function of training data and batch size\n", - "if (Use_Default_Advanced_Parameters): \n", - " number_of_steps= int(X.shape[0]/batch_size)+1\n", - "\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "#Here we ensure that the learning rate set correctly when using pre-trained models\n", - "if Use_pretrained_model:\n", - " if Weights_choice == \"last\":\n", - " initial_learning_rate = lastLearningRate\n", - "\n", - " if Weights_choice == \"best\": \n", - " initial_learning_rate = bestLearningRate\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "\n", - "#Here, we create the default Config object which sets the hyperparameters of the network training.\n", - "\n", - "config = Config(axes, n_channel_in, n_channel_out, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n", - "print(config)\n", - "vars(config)\n", - "\n", - "# Compile the CARE model for network training\n", - "\n", - "model_training= CARE(config, model_name, basedir=model_path)\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "# Load the pretrained weights \n", - "if Use_pretrained_model:\n", - " model_training.load_weights(h5_file_path)\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wQPz0F6JlvJR" - }, - "source": [ - "## **4.2. 
Start Training**\n", - "---\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n", - "\n", - "**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSB Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system)." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "j_Qm5JBmlvJg", - "cellView": "form" - }, - "source": [ - "#@markdown ##Start Training\n", - "\n", - "start = time.time()\n", - "\n", - "# Start Training\n", - "history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n", - "\n", - "print(\"Training, done.\")\n", - "\n", - "# convert the history.history dict to a pandas DataFrame: \n", - "lossData = pd.DataFrame(history.history) \n", - "\n", - "if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n", - " shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "# The training evaluation.csv is saved (overwrites the Files if needed). \n", - "lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n", - "with open(lossDataCSVpath, 'w') as f:\n", - " writer = csv.writer(f)\n", - " writer.writerow(['loss','val_loss', 'learning rate'])\n", - " for i in range(len(history.history['loss'])):\n", - " writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n", - "\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "model_training.export_TF()\n", - "\n", - "print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n", - "\n", - "#Create a pdf document with training summary\n", - "pdf_export(trained=True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QYuIOWQ3imuU" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "zazOZ3wDx0zQ" - }, - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "QC_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "QC_model_name = os.path.basename(QC_model_folder)\n", - "QC_model_path = os.path.dirname(QC_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", - "\n", - "full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n", - "if os.path.exists(full_QC_model_path):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yDY9dtzdUTLh" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased." - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "vMzSP50kMv5p" - }, - "source": [ - "#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n", - "\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "\n", - "with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[0]))\n", - " vallossDataFromCSV.append(float(row[1]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (log scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n", - "plt.show()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "biT9FI9Ri77_" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "\n", - "This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n", - "\n", - "**1. The SSIM (structural similarity) map** \n", - "\n", - "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n", - "\n", - "**mSSIM** is the SSIM value calculated across the entire window of both images.\n", - "\n", - "**The output below shows the SSIM maps with the mSSIM**\n", - "\n", - "**2. The RSE (Root Squared Error) map** \n", - "\n", - "This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n", - "\n", - "\n", - "**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n", - "\n", - "**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n", - "\n", - "**The output below shows the RSE maps with the NRMSE and PSNR values.**\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "nAs4Wni7VYbq", - "cellView": "form" - }, - "source": [ - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "\n", - "\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_QC_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n", - "\n", - "# Create a quality control/Prediction Folder\n", - "if os.path.exists(path_metrics_save+'Prediction'):\n", - " shutil.rmtree(path_metrics_save+'Prediction')\n", - "os.makedirs(path_metrics_save+'Prediction')\n", - "\n", - "#Here we allow the user to choose the number of tile to be used when predicting the images\n", - "#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n", - "\n", - "Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n", - "#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n", - "n_tiles_Z = 1#@param {type:\"number\"}\n", - "n_tiles_Y = 2#@param {type:\"number\"}\n", - "n_tiles_X = 2#@param {type:\"number\"}\n", - "\n", - "if (Automatic_number_of_tiles): \n", - " n_tilesZYX = None\n", - "\n", - "if not (Automatic_number_of_tiles):\n", - " n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n", - "\n", - "# Activate the pretrained model. 
\n", - "model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n", - "\n", - "# List Tif images in Source_QC_folder\n", - "Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n", - "Z = sorted(glob(Source_QC_folder_tif))\n", - "Z = list(map(imread,Z))\n", - "print('Number of test dataset found in the folder: '+str(len(Z)))\n", - "\n", - "\n", - "# Perform prediction on all datasets in the Source_QC folder\n", - "for filename in os.listdir(Source_QC_folder):\n", - " img = imread(os.path.join(Source_QC_folder, filename))\n", - " n_slices = img.shape[0]\n", - " predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n", - " os.chdir(path_metrics_save+'Prediction/')\n", - " imsave('Predicted_'+filename, predicted)\n", - "\n", - "\n", - "def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " \"\"\"Percentile-based image normalization.\"\"\"\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - "def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " if dtype is not None:\n", - " x = x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - "def norm_minmse(gt, x, normalize_gt=True):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - "\n", - " \"\"\"\n", - " normalizes and affinely scales an image pair such that the MSE is minimized \n", - " \n", - " Parameters\n", - " ----------\n", - " gt: ndarray\n", - " the ground truth image \n", - " x: ndarray\n", - " the image that will be affinely scaled \n", - " normalize_gt: bool\n", - " set to True of gt image should be normalized (default)\n", - " Returns\n", - " -------\n", - " gt_scaled, x_scaled \n", - " \"\"\"\n", - " if normalize_gt:\n", - " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", - " x = x.astype(np.float32, copy=False) - np.mean(x)\n", - " #x = x - np.mean(x)\n", - " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", - " #gt = gt - np.mean(gt)\n", - " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", - " return gt, scale * x\n", - "\n", - "\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - "with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n", - " \n", - " # These lists will be used to collect all the metrics values per slice\n", - " file_name_list = []\n", - " slice_number_list = []\n", - " mSSIM_GvP_list = []\n", - " mSSIM_GvS_list = []\n", - " NRMSE_GvP_list = []\n", - " NRMSE_GvS_list = []\n", - " PSNR_GvP_list = []\n", - " PSNR_GvS_list = []\n", - "\n", - " # These lists will be used to display the mean metrics for the stacks\n", - " mSSIM_GvP_list_mean = []\n", - " mSSIM_GvS_list_mean = []\n", - " NRMSE_GvP_list_mean = []\n", - " NRMSE_GvS_list_mean = []\n", - " PSNR_GvP_list_mean = []\n", - " PSNR_GvS_list_mean = []\n", - "\n", - " # Let's loop through the provided dataset in the QC folders\n", - " for thisFile in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n", - " print('Running QC on: '+thisFile)\n", - "\n", - " test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n", - " test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n", - " test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n", - " n_slices = test_GT_stack.shape[0]\n", - "\n", - " # Calculating the position of the mid-plane slice\n", - " z_mid_plane = int(n_slices / 2)+1\n", - "\n", - " img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n", - " img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n", - " img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n", - " img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n", - "\n", - " for z in range(n_slices): \n", - " # -------------------------------- Normalising the dataset --------------------------------\n", - "\n", - " test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n", - " test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n", - "\n", - " # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n", - "\n", - " # Calculate the SSIM maps and index\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n", - "\n", - " #Calculate ssim_maps\n", - " img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n", - " img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n", - " \n", - "\n", - " # -------------------------------- Calculate the NRMSE metrics --------------------------------\n", - "\n", - " # Calculate the Root Squared Error (RSE) maps\n", - " img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", - " img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n", - "\n", - " # Calculate SE maps\n", - " img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n", - " img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n", - "\n", - " # Normalised Root Mean Squared Error (here it's valid to 
take the mean of the image)\n", - " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", - " NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n", - "\n", - " # Calculate the PSNR between the images\n", - " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", - " PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n", - "\n", - " writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n", - " \n", - " # Collect values to display in dataframe output\n", - " slice_number_list.append(z)\n", - " mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n", - " mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n", - " NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n", - " NRMSE_GvS_list.append(NRMSE_GTvsSource)\n", - " PSNR_GvP_list.append(PSNR_GTvsPrediction)\n", - " PSNR_GvS_list.append(PSNR_GTvsSource)\n", - "\n", - " if (z == z_mid_plane): # catch these for display\n", - " SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n", - " SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n", - " NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n", - " NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n", - " \n", - " # If calculating average metrics for dataframe output\n", - " file_name_list.append(thisFile)\n", - " mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n", - " mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n", - " NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n", - " NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n", - " PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n", - " PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n", - "\n", - " # ----------- Change the stacks to 32 bit images -----------\n", - "\n", - " img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n", - " img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n", - " img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n", - " img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n", - "\n", - " # ----------- Saving the error map stacks -----------\n", - " io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n", - " io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n", - " io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n", - " io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n", - "\n", - "#Averages of the metrics per stack as dataframe output\n", - "pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n", - "pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n", - "pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n", - "pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n", - "pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n", - "pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n", - "pdResults[\"Input v. 
GT PSNR\"] = PSNR_GvS_list_mean\n", - "\n", - "# All data is now processed saved\n", - "Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n", - "\n", - "plt.figure(figsize=(20,20))\n", - "# Currently only displays the last computed set, from memory\n", - "# Target (Ground-truth)\n", - "plt.subplot(3,3,1)\n", - "plt.axis('off')\n", - "img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n", - "\n", - "# Calculating the position of the mid-plane slice\n", - "z_mid_plane = int(img_GT.shape[0] / 2)+1\n", - "\n", - "plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n", - "plt.title('Target (slice #'+str(z_mid_plane)+')')\n", - "\n", - "# Source\n", - "plt.subplot(3,3,2)\n", - "plt.axis('off')\n", - "img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n", - "plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n", - "plt.title('Source (slice #'+str(z_mid_plane)+')')\n", - "\n", - "#Prediction\n", - "plt.subplot(3,3,3)\n", - "plt.axis('off')\n", - "img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n", - "plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n", - "plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n", - "\n", - "#Setting up colours\n", - "cmap = plt.cm.CMRmap\n", - "\n", - "#SSIM between GT and Source\n", - "plt.subplot(3,3,5)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n", - "imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n", - "plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - "plt.title('Target vs. Source',fontsize=15)\n", - "plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n", - "plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - "plt.subplot(3,3,6)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n", - "imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n", - "plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - "plt.title('Target vs. 
Prediction',fontsize=15)\n", - "plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n", - "\n", - "#Root Squared Error between GT and Source\n", - "plt.subplot(3,3,8)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n", - "imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n", - "plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n", - "plt.title('Target vs. Source',fontsize=15)\n", - "plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n", - "#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n", - "plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#Root Squared Error between GT and Prediction\n", - "plt.subplot(3,3,9)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n", - "imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n", - "plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", - "plt.title('Target vs. Prediction',fontsize=15)\n", - "plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n", - "plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "print('-----------------------------------')\n", - "print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n", - "pdResults.head()\n", - "\n", - "\n", - "#Make a pdf summary of the QC results\n", - "\n", - "qc_pdf_export()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "69aJVFfsqXbY" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tcPNRq1TrMPB" - }, - "source": [ - "## **6.1. Generate prediction(s) from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.2) can now be used to process images. 
If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n", - "\n", - "**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output images." - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "Am2JSmpC0frj" - }, - "source": [ - "\n", - "#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n", - "\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Result_folder = \"\" #@param {type:\"string\"}\n", - "\n", - " \n", - "# model name and path\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "\n", - "#Here we allow the user to choose the number of tile to be used when predicting the images\n", - "#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n", - "\n", - "Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n", - "#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n", - "n_tiles_Z = 1#@param {type:\"number\"}\n", - "n_tiles_Y = 2#@param {type:\"number\"}\n", - "n_tiles_X = 2#@param {type:\"number\"}\n", - "\n", - "if (Automatic_number_of_tiles): \n", - " n_tilesZYX = None\n", - "\n", - "if not (Automatic_number_of_tiles):\n", - " n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n", - "\n", - "#Activate the pretrained model. 
\n", - "model=CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n", - "\n", - "print(\"Restoring images...\")\n", - "\n", - "thisdir = Path(Data_folder)\n", - "outputdir = Path(Result_folder)\n", - "suffix = '.tif'\n", - "\n", - "# r=root, d=directories, f = files\n", - "for r, d, f in os.walk(thisdir):\n", - " for file in f:\n", - " if \".tif\" in file:\n", - " print(os.path.join(r, file))\n", - "\n", - "for r, d, f in os.walk(thisdir):\n", - " for file in f:\n", - " base_filename = os.path.basename(file)\n", - " input_train = imread(os.path.join(r, file))\n", - " pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n", - " save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX') \n", - "\n", - "print(\"Images saved into the result folder:\", Result_folder)\n", - "\n", - "#Display an example\n", - "\n", - "random_choice=random.choice(os.listdir(Data_folder))\n", - "x = imread(Data_folder+\"/\"+random_choice)\n", - "\n", - "z_mid_plane = int(x.shape[0] / 2)+1\n", - "\n", - "@interact\n", - "def show_results(file=os.listdir(Data_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n", - " x = imread(Data_folder+\"/\"+file)\n", - " y = imread(Result_folder+\"/\"+file)\n", - "\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n", - " plt.axis('off')\n", - " plt.title('Noisy Input (single Z plane)');\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n", - " plt.axis('off')\n", - " plt.title('Prediction (single Z plane)');\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.2. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u4pcBe8Z3T2J" - }, - "source": [ - "#**Thank you for using CARE 3D!**" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"CARE_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hMjEc-Ex7j-jeYGclaPw2x3OgbkeC6Bl","timestamp":1610626439596},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **CARE: Content-aware image restoration (3D)**\n","\n","---\n","\n","CARE is a neural network capable of image restoration from corrupted bio-images, first published in 2018 by [Weigert *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0216-7). The CARE network uses a U-Net network architecture and allows image restoration and resolution improvement in 2D and 3D images, in a supervised manner, using noisy images as input and low-noise images as targets for training. The function of the network is essentially determined by the set of images provided in the training dataset. For instance, if noisy images are provided as input and high signal-to-noise ratio images are provided as targets, the network will perform denoising.\n","\n"," **This particular notebook enables restoration of 3D dataset. If you are interested in restoring 2D dataset, you should use the CARE 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper: \n","\n","**Content-aware image restoration: pushing the limits of fluorescence microscopy**, by Weigert *et al.* published in Nature Methods in 2018 (https://www.nature.com/articles/s41592-018-0216-7)\n","\n","And source code found in: /~https://github.com/csbdeep/csbdeep\n","\n","For a more in-depth description of the features of the network,please refer to [this guide](http://csbdeep.bioimagecomputing.com/doc/) provided by the original authors of the work.\n","\n","We provide a dataset for the training of this notebook as a way to test its functionalities but the training and test data of the restoration experiments is also available from the authors of the original paper [here](https://publications.mpi-cbg.de/publications-sites/7207/).\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. 
This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["#**0. Before getting started**\n","---\n"," For CARE to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions (for instance, low signal-to-noise ratio and high signal-to-noise ratio) and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Low SNR images\" (Training_source) and \"Training - high SNR images\" (Training_target). Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. \n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Install CARE and dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"ohimRYp7UOv-"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["#@markdown ##Install CARE and dependencies\n","\n","\n","#Here, we install libraries which are not already included in Colab.\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0ntZU7cKUSpp"},"source":["\n","## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"GV9Saw1JUVVP"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"TYEDoeMMUY96","cellView":"form"},"source":["#@markdown ##Load key dependencies\n","\n","Notebook_version = '1.13'\n","Network = 'CARE (3D)'\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#Here, we import and enable Tensorflow 1 instead of Tensorflow 2.\n","%tensorflow_version 1.x\n","import tensorflow\n","import tensorflow as tf\n","\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to CARE -------\n","from csbdeep.utils import download_and_extract_zip_file, normalize, plot_some, axes_dict, plot_history, Path, download_and_extract_zip_file\n","from csbdeep.data import RawData, create_patches \n","from csbdeep.io import load_training_data, save_tiff_imagej_compatible\n","from csbdeep.models import Config, CARE\n","from csbdeep import data\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from 
pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","\n"," #model_name = 'little_CARE_test'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: 
('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>patch_height</td><td>{2}</td></tr>
<tr><td>number_of_patches</td><td>{3}</td></tr>
<tr><td>batch_size</td><td>{4}</td></tr>
<tr><td>number_of_steps</td><td>{5}</td></tr>
<tr><td>percentage_validation</td><td>{6}</td></tr>
<tr><td>initial_learning_rate</td><td>{7}</td></tr>
\n"," \"\"\".format(number_of_epochs,patch_size,patch_height,number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(32, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_CARE3D.png').shape\n"," pdf.image('/content/TrainingDataExample_CARE3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'CARE 3D'\n"," #model_name = os.path.basename(QC_model_folder)\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th><th>{7}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td><td>{7}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"b4-r1gE7Iamv"},"source":["# **2. Initialise the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"BDhmUgqCStlm","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DvuiNvCbbeKN"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
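Before the training paths are entered in section 3, it can be useful to confirm that the mounted dataset follows the paired structure described in section 0, i.e. identically named .tif files in the low-SNR and high-SNR folders. The cell below is not part of the notebook; it is a minimal sketch of such a check, and the two paths are placeholders for your own Drive folders.

```python
# Hypothetical sanity check (not part of the notebook): verify the paired
# training folders described in section 0 before filling in section 3.
import os

Training_source = "/content/gdrive/MyDrive/Training - Low SNR images"   # placeholder path
Training_target = "/content/gdrive/MyDrive/Training - high SNR images"  # placeholder path

source_files = {f for f in os.listdir(Training_source) if f.endswith(".tif")}
target_files = {f for f in os.listdir(Training_target) if f.endswith(".tif")}

# Corresponding input and target files must have exactly the same name
missing_targets = source_files - target_files
missing_sources = target_files - source_files

print(f"{len(source_files & target_files)} matched .tif pairs found")
if missing_targets:
    print("Source images without a matching target:", sorted(missing_targets))
if missing_sources:
    print("Target images without a matching source:", sorted(missing_sources))
```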
"]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (Low SNR images) and Training_target (High SNR images or ground truth) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number of epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-300 epochs. Evaluate the performance after training (see 5.). **Default value: 40**\n","\n","**`patch_size`:** CARE divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 80**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**When choosing the patch_size and patch_height, the values should be i) large enough that they will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**`number_of_patches`:** Input the number of the patches per image. Increasing the number of patches allows for larger training datasets. **Default value: 200** \n","\n","**Decreasing the patch size or increasing the number of patches may improve the training but may also increase the training time.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"ewpNJ_I0Mv47"},"source":["\n","#@markdown ###Path to training images:\n","\n","# base folder of GT and low images\n","base = \"/content\"\n","\n","# low SNR images\n","Training_source = \"\" #@param {type:\"string\"}\n","lowfile = Training_source+\"/*.tif\"\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","GTfile = Training_target+\"/*.tif\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# create the training data file into model_path folder.\n","training_data = model_path+\"/my_training_data.npz\"\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 40#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 80#@param {type:\"number\"} # pixels in\n","patch_height = 8#@param {type:\"number\"}\n","number_of_patches = 200#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","batch_size = 16#@param {type:\"number\"}\n","number_of_steps = 300#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n"," \n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","#Load one randomly chosen training source file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","\n","#Load one randomly chosen training target file\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Low SNR image (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('High SNR image (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_CARE3D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by rotating the training images in the XY-Plane and flipping them along X-Axis.\n","\n","**The flip option alone will double the size of your dataset, rotation will quadruple and both together will increase the dataset by a factor of 8.**"]},{"cell_type":"code","metadata":{"cellView":"form","id":"htqjkJWt5J_8"},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," 
io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CARE 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pret-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
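The idea implemented by the next cell can be condensed as follows. This is an illustrative sketch only, assuming a ZeroCostDL4Mic-style training_evaluation.csv with 'loss', 'val_loss' and 'learning rate' columns and a placeholder model path; the notebook cell below additionally handles missing files and older models without a 'learning rate' column.

```python
# Illustration only: recovering the previous learning rate from the
# training_evaluation.csv written by ZeroCostDL4Mic (columns: loss, val_loss, learning rate).
import os
import pandas as pd

pretrained_model_path = "/content/gdrive/MyDrive/my_previous_model"  # placeholder path
csv_path = os.path.join(pretrained_model_path, "Quality Control", "training_evaluation.csv")

history = pd.read_csv(csv_path)

# "last": learning rate at the final epoch; "best": learning rate at the lowest validation loss
lastLearningRate = history["learning rate"].iloc[-1]
bestLearningRate = history.loc[history["val_loss"].idxmin(), "learning rate"]

print("Last learning rate:", lastLearningRate)
print("Learning rate at best validation loss:", bestLearningRate)
```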
"]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = 
initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead')\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead')\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# This object holds the image pairs (GT and low), ensuring that CARE compares corresponding images.\n","# This file is saved in .npz format and later called when loading the trainig data.\n","\n","if Use_Data_augmentation == True:\n"," Training_source = Saving_path+'/augmented_source'\n"," Training_target = Saving_path+'/augmented_target'\n","\n","raw_data = RawData.from_folder (\n"," basepath = base,\n"," source_dirs = [Training_source],\n"," target_dir = Training_target,\n"," axes = 'ZYX',\n"," pattern='*.tif*'\n",")\n","X, Y, XY_axes = create_patches (\n"," raw_data = raw_data,\n"," patch_size = (patch_height,patch_size,patch_size),\n"," n_patches_per_image = number_of_patches, \n"," save_file = training_data,\n",")\n","\n","assert X.shape == Y.shape\n","print(\"shape of X,Y =\", X.shape)\n","print(\"axes of X,Y =\", XY_axes)\n","\n","print ('Creating 3D training dataset')\n","\n","# Load Training Data\n","(X,Y), (X_val,Y_val), axes = load_training_data(training_data, validation_split=percentage, verbose=True)\n","c = axes_dict(axes)['C']\n","n_channel_in, n_channel_out = X.shape[c], Y.shape[c]\n","\n","#Plot example patches\n","\n","#plot of training patches.\n","plt.figure(figsize=(12,5))\n","plot_some(X[:5],Y[:5])\n","plt.suptitle('5 example training patches (top row: source, bottom row: target)');\n","\n","#plot of validation patches\n","plt.figure(figsize=(12,5))\n","plot_some(X_val[:5],Y_val[:5])\n","plt.suptitle('5 example validation patches (top row: source, bottom row: target)');\n","\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here, we create the default Config object which sets the hyperparameters of the network training.\n","\n","config = Config(axes, n_channel_in, n_channel_out, train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, train_batch_size=batch_size, train_learning_rate=initial_learning_rate)\n","print(config)\n","vars(config)\n","\n","# Compile the CARE model for network training\n","\n","model_training= CARE(config, model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model_training.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or the number of patches. Another way to circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive, as all data can be erased at the next training if the same folder is used.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, make sure to choose the right version of TensorFlow. You can check at: Edit -- Options -- Tensorflow. Choose version 1.4 (CPU or GPU depending on your system)."]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["#@markdown ##Start Training\n","\n","start = time.time()\n","\n","# Start Training\n","history = model_training.train(X,Y, validation_data=(X_val,Y_val))\n","\n","print(\"Training done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training_evaluation.csv file is saved (overwrites the file if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model_training.export_TF()\n","\n","print(\"Your model has been successfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","pdf_export(trained=True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"zazOZ3wDx0zQ"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","id":"vMzSP50kMv5p"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
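As a quick, self-contained illustration of the metrics described above, the same scikit-image calls used by the Quality Control cell below can be run on a single pair of normalised slices. This is only a sketch of the computation on synthetic stand-in arrays, not part of the notebook's workflow:

```python
# Minimal sketch of the per-slice QC metrics; gt/pred are synthetic stand-ins for a
# normalised ground-truth slice and the corresponding prediction.
import numpy as np
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

gt = np.random.rand(128, 128).astype(np.float32)
pred = np.clip(gt + 0.05 * np.random.randn(128, 128), 0, 1).astype(np.float32)

# SSIM map and mean SSIM (11-pixel window with Gaussian weighting, sigma = 1.5, as above)
mssim, ssim_map = structural_similarity(gt, pred, data_range=1.0, full=True,
                                        gaussian_weights=True, use_sample_covariance=False, sigma=1.5)

# Root Squared Error map, NRMSE (computed as in the QC cell below) and PSNR in decibels
rse_map = np.sqrt(np.square(gt - pred))
nrmse = np.sqrt(np.mean(rse_map))
psnr_value = peak_signal_noise_ratio(gt, pred, data_range=1.0)
print(f"mSSIM: {mssim:.3f}, NRMSE: {nrmse:.3f}, PSNR: {psnr_value:.2f} dB")
```

An mSSIM close to 1, a low NRMSE and a high PSNR all indicate good agreement; the cell in this section repeats this slice by slice over whole stacks, both for the prediction and for the raw source as a baseline.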
For PSNR, the higher the score, the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","# Activate the pretrained model. 
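# --- Added explanatory comments (not part of the original cell) ---
# CARE(config=None, name=QC_model_name, basedir=QC_model_path) below reloads the already-trained
# model for inference only; no further training happens during Quality Control.
# n_tilesZYX is either None ("Automatic_number_of_tiles", as described above) or a (Z, Y, X)
# tuple controlling how each test stack is split into tiles so that large volumes fit in GPU memory.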
\n","model_training = CARE(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm, test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm, test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction, force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource, force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction, force_copy=False)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource, force_copy=False)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR 
between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. 
GT PSNR\"] = PSNR_GvS_list_mean\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(20,20))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane], norm=simple_norm(img_GT[z_mid_plane], percent = 99))\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane], norm=simple_norm(img_Source[z_mid_plane], percent = 99))\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane], norm=simple_norm(img_Prediction[z_mid_plane], percent = 99))\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
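For orientation, the prediction cell that follows simply loads the trained model once and then calls `predict` with explicit tiling for every TIFF found in `Data_folder`. A minimal sketch for a single stack, with hypothetical folder and model names (the actual cell handles the whole folder and adds an interactive display):

```python
# Minimal single-stack sketch of the prediction step; all paths and names below are hypothetical.
from csbdeep.models import CARE
from csbdeep.io import save_tiff_imagej_compatible
from tifffile import imread

model = CARE(config=None, name="my_model", basedir="/content/gdrive/MyDrive/models")  # trained model folder
stack = imread("/content/gdrive/MyDrive/to_predict/example_stack.tif")                # one ZYX stack

# n_tiles = (Z, Y, X); increase the Y/X values if an out-of-memory error occurs
restored = model.predict(stack, axes='ZYX', n_tiles=(1, 2, 2))
save_tiff_imagej_compatible("/content/gdrive/MyDrive/results/example_stack.tif", restored, axes='ZYX')
```

As in the cell below, the restored stack is written as an ImageJ-compatible TIFF, and files with the same name in the result folder are overwritten.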
Predicted output images are saved in your **Result_folder** as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","id":"Am2JSmpC0frj"},"source":["\n","#@markdown ##Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n"," \n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model. 
\n","model=CARE(config=None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Restoring images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX') \n","\n","print(\"Images saved into the result folder:\", Result_folder)\n","\n","#Display an example\n","\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","z_mid_plane = int(x.shape[0] / 2)+1\n","\n","@interact\n","def show_results(file=os.listdir(Data_folder), z_plane=widgets.IntSlider(min=0, max=(x.shape[0]-1), step=1, value=z_mid_plane)):\n"," x = imread(Data_folder+\"/\"+file)\n"," y = imread(Result_folder+\"/\"+file)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[z_plane], norm=simple_norm(x[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Noisy Input (single Z plane)');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[z_plane], norm=simple_norm(y[z_plane], percent = 99), interpolation='nearest')\n"," plt.axis('off')\n"," plt.title('Prediction (single Z plane)');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"Ka4SQ-h_cKIv"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version now includes an automatic restart allowing to set the h5py library to v2.10.\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","\n","* This version also now includes built-in version check and the version log that you're reading now.\n"]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using CARE 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/ChangeLog.txt b/Colab_notebooks/ChangeLog.txt index 1bbc0528..4d9b8955 100644 --- a/Colab_notebooks/ChangeLog.txt +++ b/Colab_notebooks/ChangeLog.txt @@ -7,6 +7,28 @@ https://www.biorxiv.org/content/10.1101/2020.03.20.000133v4 Latest releases available here: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/releases +————————————————————————————————————————————————————————— +ZeroCostDL4Mic v1.13 + +Major changes: +- added a new export of requirements.txt file with minimal packages, better compatibility with local runtime sessions +- Beta notebooks: new notebooks available: Detectron2 (object detection), MaskRCNN (object detection and segmentation), DRMIME (image registration), Cellpose (image segmentation) (2D) and DecoNoising (denoising) (2D). 
+- Addition of the Interactive segmentation - Cellpose notebook, using Kaibu and ImJoy, big thanks to Wei Ouyang for helping us get this up and running. +- Tools: A notebook to perform Quality Control has been added +- Noise2Void 2D and 3D now use the latest code release based on TensorFlow 2 +- The version check is now done on a per-notebook basis, making release of individual notebooks easier. +- U-Net 3D imports Keras libraries via TensorFlow. +- Section 1 and 2 in the notebooks have been swapped for a better flow and improved capabilities for export of requirements.txt +- Each notebook now includes a version log, that will be individually amended when doing individual notebook releases. +- Fnet: A 2D notebook was added, the fnet notebooks have an additional cell that creates the data files, Fnet 3D's re-training cell now has a module to reduce images in the buffer, just as the main training cell (to avoid OOM error). +- YOLOv2: The repository is now saved in the content folder of colab rather than the users gdrive, consistent with other notebooks , Commented code was removed. + ++ minor bug fixes + +Beta notebooks: +- model export to BioImage Model Zoo format for DeepImageJ for 2D U-Net, 3D U-Net and Deep-STORM +- Cellpose notebook now exports PDF during training and QC. Cellpose now trains using Torch. The new Cyto2 model is also available. + ————————————————————————————————————————————————————————— ZeroCostDL4Mic v1.12 diff --git a/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb b/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb index 812586d8..6c49bf64 100644 --- a/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/CycleGAN_ZeroCostDL4Mic.ipynb @@ -1,2148 +1 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "IkSguVy8Xv83" - }, - "source": [ - "# **CycleGAN**\n", - "\n", - "---\n", - "\n", - "CycleGAN is a method that can capture the characteristics of one image domain and learn how these characteristics can be translated into another image domain, all in the absence of any paired training examples. It was first published by [Zhu *et al.* in 2017](https://arxiv.org/abs/1703.10593). Unlike pix2pix, the image transformation performed does not require paired images for training (unsupervised learning) and is made possible here by using a set of two Generative Adversarial Networks (GANs) that learn to transform images both from the first domain to the second and vice-versa.\n", - "\n", - " **This particular notebook enables unpaired image-to-image translation. If your dataset is paired, you should also consider using the pix2pix notebook.**\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is based on the following paper: \n", - "\n", - " **Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks** from Zhu *et al.* published in arXiv in 2018 (https://arxiv.org/abs/1703.10593)\n", - "\n", - "The source code of the CycleGAN PyTorch implementation can be found in: /~https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jqvkQQkcuMmM" - }, - "source": [ - "# **License**\n", - "\n", - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "vCihhAzluRvI" - }, - "outputs": [], - "source": [ - "#@markdown ##Double click to see the license information\n", - "\n", - "#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n", - "#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n", - "\n", - "\n", - "\n", - "#------------------------- LICENSE FOR CycleGAN ------------------------------------\n", - "\n", - "#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n", - "#All rights reserved.\n", - "\n", - "#Redistribution and use in source and binary forms, with or without\n", - "#modification, are permitted provided that the following conditions are met:\n", - "\n", - "#* Redistributions of source code must retain the above copyright notice, this\n", - "# list of conditions and the following disclaimer.\n", - "\n", - "#* Redistributions in binary form must reproduce the above copyright notice,\n", - "# this list of conditions and the following disclaimer in the documentation\n", - "# and/or other materials provided with the distribution.\n", - "\n", - "#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n", - "#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n", - "#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n", - "#DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n", - "#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n", - "#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n", - "#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n", - "#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n", - "#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n", - "#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", - "\n", - "\n", - "#--------------------------- LICENSE FOR pix2pix --------------------------------\n", - "#BSD License\n", - "\n", - "#For pix2pix software\n", - "#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n", - "#All rights reserved.\n", - "\n", - "#Redistribution and use in source and binary forms, with or without\n", - "#modification, are permitted provided that the following conditions are met:\n", - "\n", - "#* Redistributions of source code must retain the above copyright notice, this\n", - "# list of conditions and the following disclaimer.\n", - "\n", - "#* Redistributions in binary form must reproduce the above copyright notice,\n", - "# this list of conditions and the following disclaimer in the documentation\n", - "# and/or other materials provided with the distribution.\n", - "\n", - "#----------------------------- LICENSE FOR DCGAN --------------------------------\n", - "#BSD License\n", - "\n", - "#For dcgan.torch software\n", - "\n", - "#Copyright (c) 2015, Facebook, Inc. All rights reserved.\n", - "\n", - "#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n", - "\n", - "#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n", - "\n", - "#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n", - "\n", - "#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n", - "\n", - "#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - " To train CycleGAN, **you only need two folders containing PNG images**. The images do not need to be paired.\n", - "\n", - "While you do not need paired images to train CycleGAN, if possible, **we strongly recommend that you generate a paired dataset. 
This means that the same image needs to be acquired in the two conditions. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n", - "\n", - "\n", - " Please note that you currently can **only use .png files!**\n", - "\n", - "\n", - "Here's a common data structure that can work:\n", - "* Experiment A\n", - " - **Training dataset (non-matching images)**\n", - " - Training_source\n", - " - img_1.png, img_2.png, ...\n", - " - Training_target\n", - " - img_1.png, img_2.png, ...\n", - " - **Quality control dataset (matching images)**\n", - " - Training_source\n", - " - img_1.png, img_2.png\n", - " - Training_target\n", - " - img_1.png, img_2.png\n", - " - **Data to be predicted**\n", - " - **Results**\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "---\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "zCvebubeSaGY" - }, - "outputs": [], - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sNIVx8_CLolt" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. 
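Once the Drive is mounted, it is worth confirming that the two training folders described in section 0 actually contain PNG files before starting a long training run. A minimal sketch, with hypothetical paths (not part of the notebook's cells):

```python
# Minimal sketch (hypothetical paths): check that both training folders contain PNG images,
# since this notebook currently only accepts .png files.
import os
from glob import glob

Training_source = "/content/gdrive/MyDrive/ExperimentA/Training_source"   # hypothetical path
Training_target = "/content/gdrive/MyDrive/ExperimentA/Training_target"   # hypothetical path

for folder in (Training_source, Training_target):
    n_png = len(glob(os.path.join(folder, "*.png")))
    print(f"{folder}: {n_png} PNG file(s)")
```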
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "01Djr8v-5pPk" - }, - "outputs": [], - "source": [ - "#@markdown ##Play the cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AdN8B91xZO0x" - }, - "source": [ - "# **2. Install CycleGAN and dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "fq21zJVFNASx" - }, - "outputs": [], - "source": [ - "Notebook_version = ['1.12']\n", - "\n", - "\n", - "\n", - "#@markdown ##Install CycleGAN and dependencies\n", - "\n", - "\n", - "#------- Code from the cycleGAN demo notebook starts here -------\n", - "\n", - "#Here, we install libraries which are not already included in Colab.\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "!git clone /~https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n", - "\n", - "import os\n", - "os.chdir('pytorch-CycleGAN-and-pix2pix/')\n", - "!pip install -r requirements.txt\n", - "!pip install fpdf\n", - "\n", - "import imageio\n", - "from skimage import data\n", - "from skimage import exposure\n", - "from skimage.exposure import match_histograms\n", - "\n", - "from skimage.util import img_as_int\n", - "\n", - "\n", - "\n", - "\n", - "# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - "import urllib\n", - "import os, random\n", - "import shutil \n", - "import zipfile\n", - "from tifffile import imread, imsave\n", - "import time\n", - "import sys\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import csv\n", - "from glob import glob\n", - "from scipy import signal\n", - "from scipy import ndimage\n", - "from skimage import io\n", - "from sklearn.linear_model import LinearRegression\n", - "from skimage.util import img_as_uint\n", - "import matplotlib as mpl\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from astropy.visualization import simple_norm\n", - "from skimage import img_as_float32\n", - "from skimage.util import img_as_ubyte\n", - "from tqdm import tqdm \n", - "from fpdf import FPDF, HTMLMixin\n", - "from datetime import datetime\n", - "from pip._internal.operations.freeze import freeze\n", - "import subprocess\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "\n", - "#Disable some of the tensorflow warnings\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "print(\"Libraries installed\")\n", - "\n", - "# Check if this is 
the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "if Notebook_version == list(Latest_notebook_version.columns):\n", - " print(\"This notebook is up-to-date.\")\n", - "\n", - "if not Notebook_version == list(Latest_notebook_version.columns):\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'cycleGAN'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell\n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and method:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','torch']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and an least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented by default'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ParameterValue
number_of_epochs{0}
patch_size{1}
batch_size{2}
initial_learning_rate{3}
\n", - " \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_cycleGAN.png').shape\n", - " pdf.image('/content/TrainingDataExample_cycleGAN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " # if Use_Data_augmentation:\n", - " # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", - "\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'cycleGAN'\n", - "\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png').shape\n", - " pdf.image(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n", - " if Image_type == 'RGB':\n", - " pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/5), h = round(exp_size[0]/5))\n", - " if Image_type == 'Grayscale':\n", - " pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " for checkpoint in os.listdir(full_QC_model_path+'Quality Control'):\n", - " if os.path.isdir(os.path.join(full_QC_model_path,'Quality Control',checkpoint)):\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(70, 5, txt = 'Metrics for checkpoint: '+ str(checkpoint), align='L', ln=1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+'Quality Control/'+str(checkpoint)+'/QC_metrics_'+QC_model_name+str(checkpoint)+'.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " mSSIM_PvsGT = header[1]\n", - " mSSIM_SvsGT = header[2]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT)\n", - " html = html+header\n", - " for row in metrics:\n", - " image = row[0]\n", - " mSSIM_PvsGT = row[1]\n", - " mSSIM_SvsGT = 
row[2]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
</body></table>
\"\"\"\n", - " pdf.write_html(html)\n", - " pdf.ln(2)\n", - " else:\n", - " continue\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "# Exporting requirements.txt for local run\n", - "!pip freeze > ../requirements.txt\n", - "\n", - "after = [str(m) for m in sys.modules]\n", - "# Get minimum requirements file\n", - "\n", - "#Add the following lines before all imports: \n", - "# import sys\n", - "# before = [str(m) for m in sys.modules]\n", - "\n", - "#Add the following line after the imports:\n", - "# after = [str(m) for m in sys.modules]\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "df = pd.read_csv('../requirements.txt', delimiter = \"\\n\")\n", - "mod_list = [m.split('.')[0] for m in after if not m in before]\n", - "req_list_temp = df.values.tolist()\n", - "req_list = [x[0] for x in req_list_temp]\n", - "\n", - "# Replace with package name \n", - "mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - "mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - "filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - "file=open('../CycleGAN_requirements_simple.txt','w')\n", - "for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - "file.close()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HLYcZR9gMv42" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FQ_QxtSWQ7CL" - }, - "source": [ - "## **3.1. Setting main training parameters**\n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AuESFimvMv43" - }, - "source": [ - " **Paths for training, predictions and results**\n", - "\n", - "**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respecively. 
To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "**Training Parameters**\n", - "\n", - "**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n", - "\n", - "\n", - "**Advanced Parameters - experienced users only**\n", - "\n", - "**`patch_size`:** CycleGAN divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 4. **Default value: 512**\n", - "\n", - "**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n", - "\n", - "**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0002**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "ewpNJ_I0Mv47" - }, - "outputs": [], - "source": [ - "\n", - "\n", - "#@markdown ###Path to training images:\n", - "\n", - "Training_source = \"\" #@param {type:\"string\"}\n", - "InputFile = Training_source+\"/*.png\"\n", - "\n", - "Training_target = \"\" #@param {type:\"string\"}\n", - "OutputFile = Training_target+\"/*.png\"\n", - "\n", - "\n", - "#Define where the patch file will be saved\n", - "base = \"/content\"\n", - "\n", - "\n", - "# model name and path\n", - "#@markdown ###Name of the model and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# other parameters for training.\n", - "#@markdown ###Training Parameters\n", - "#@markdown Number of epochs:\n", - "number_of_epochs = 200#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "\n", - "Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "patch_size = 512#@param {type:\"number\"} # in pixels\n", - "batch_size = 1#@param {type:\"number\"}\n", - "initial_learning_rate = 0.0002 #@param {type:\"number\"}\n", - "\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 1\n", - " patch_size = 512\n", - " initial_learning_rate = 0.0002\n", - "\n", - "#here we check that no model with the same name already exist, if so delete\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n", - " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n", - " \n", - "\n", - "\n", - "#To use Cyclegan we need to organise the data in a way the model can understand\n", - "\n", - "Saving_path= \"/content/\"+model_name\n", - "#Saving_path= model_path+\"/\"+model_name\n", - "\n", - "if os.path.exists(Saving_path):\n", - " shutil.rmtree(Saving_path)\n", - "os.makedirs(Saving_path)\n", - "\n", - "TrainA_Folder = Saving_path+\"/trainA\"\n", - "if os.path.exists(TrainA_Folder):\n", - " shutil.rmtree(TrainA_Folder)\n", - "os.makedirs(TrainA_Folder)\n", - " \n", - "TrainB_Folder = Saving_path+\"/trainB\"\n", - "if os.path.exists(TrainB_Folder):\n", - " shutil.rmtree(TrainB_Folder)\n", - "os.makedirs(TrainB_Folder)\n", - "\n", - "# Here we disable pre-trained model by default (in case the cell is not ran)\n", - "Use_pretrained_model = False\n", - "\n", - "# Here we disable data augmentation by default (in case the cell is not ran)\n", - "\n", - "Use_Data_augmentation = True\n", - "\n", - "\n", - "# This will display a randomly chosen dataset input and output\n", - "random_choice = random.choice(os.listdir(Training_source))\n", - "x = imageio.imread(Training_source+\"/\"+random_choice)\n", - "\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "Image_min_dim = min(Image_Y, Image_X)\n", - "\n", - "\n", - "\n", - "#Hyperparameters failsafes\n", - "if patch_size > min(Image_Y, Image_X):\n", - " patch_size = min(Image_Y, Image_X)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "# Here we check that patch_size is divisible by 4\n", - "if not patch_size % 4 == 0:\n", - " patch_size = ((int(patch_size / 4)-1) * 4)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "\n", - "random_choice_2 = random.choice(os.listdir(Training_target))\n", - "y = imageio.imread(Training_target+\"/\"+random_choice_2)\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x, interpolation='nearest')\n", - "plt.title('Training source')\n", - "plt.axis('off');\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y, interpolation='nearest')\n", - "plt.title('Training target')\n", - "plt.axis('off');\n", - "plt.savefig('/content/TrainingDataExample_cycleGAN.png',bbox_inches='tight',pad_inches=0)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xyQZKby8yFME" - }, - "source": [ - "## **3.2. Data augmentation**\n", - "---\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_jCy7xOx2g3" - }, - "source": [ - "Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n", - "\n", - "Data augmentation is performed here by flipping the patches. \n", - "\n", - " By default data augmentation is enabled." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "DMqWq5-AxnFU" - }, - "outputs": [], - "source": [ - "#Data augmentation\n", - "\n", - "#@markdown ##Play this cell to enable or disable data augmentation: \n", - "\n", - "Use_Data_augmentation = True #@param {type:\"boolean\"}\n", - "\n", - "if Use_Data_augmentation:\n", - " print(\"Data augmentation enabled\")\n", - "\n", - "if not Use_Data_augmentation:\n", - " print(\"Data augmentation disabled\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3L9zSGtORKYI" - }, - "source": [ - "\n", - "## **3.3. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CycleGAN model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "9vC2n-HeLdiJ" - }, - "outputs": [], - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - " h5_file_path_A = os.path.join(pretrained_model_path, \"latest_net_G_A.pth\")\n", - " h5_file_path_B = os.path.join(pretrained_model_path, \"latest_net_G_B.pth\")\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "\n", - " if not os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n", - " print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n", - " Use_pretrained_model = False\n", - " print(bcolors.WARNING+'No pretrained network will be used.')\n", - "\n", - " if os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n", - " print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n", - " \n", - "else:\n", - " print(bcolors.WARNING+'No pretrained network will be used.')\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MCGklf1vZf2M" - }, - "source": [ - "# **4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1KYOuygETJkT" - }, - "source": [ - "## **4.1. Prepare the training data for training**\n", - "---\n", - "Here, we use the information from 3. to prepare the training data into a suitable format for training." 
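For reference, "a suitable format" here means the unpaired dataset layout that the CycleGAN implementation expects: a single dataroot containing a `trainA` sub-folder (source images) and a `trainB` sub-folder (target images). The sketch below uses hypothetical paths purely to illustrate that layout; the cell that follows builds it automatically from `Training_source` and `Training_target`.

```python
import os
import shutil

# Minimal sketch with hypothetical paths: CycleGAN reads unpaired images
# from <dataroot>/trainA (domain A, source) and <dataroot>/trainB (domain B, target).
dataroot = "/content/my_model"                 # hypothetical dataroot
source_dir = "/content/gdrive/MyDrive/source"  # hypothetical source folder
target_dir = "/content/gdrive/MyDrive/target"  # hypothetical target folder

for sub_folder, origin in (("trainA", source_dir), ("trainB", target_dir)):
    destination = os.path.join(dataroot, sub_folder)
    os.makedirs(destination, exist_ok=True)
    for file_name in os.listdir(origin):
        shutil.copyfile(os.path.join(origin, file_name),
                        os.path.join(destination, file_name))
```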
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "lIUAOJ_LMv5E" - }, - "outputs": [], - "source": [ - "#@markdown ##Prepare the data for training\n", - "\n", - "print(\"Data preparation in progress\")\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "os.makedirs(model_path+'/'+model_name)\n", - "\n", - "#--------------- Here we move the files to trainA and train B ---------\n", - "\n", - "\n", - "for f in os.listdir(Training_source):\n", - " shutil.copyfile(Training_source+\"/\"+f, TrainA_Folder+\"/\"+f)\n", - "\n", - "for files in os.listdir(Training_target):\n", - " shutil.copyfile(Training_target+\"/\"+files, TrainB_Folder+\"/\"+files)\n", - "\n", - "#---------------------------------------------------------------------\n", - "\n", - "# CycleGAN use number of EPOCH withouth lr decay and number of EPOCH with lr decay\n", - "\n", - "\n", - "number_of_epochs_lr_stable = int(number_of_epochs/2)\n", - "number_of_epochs_lr_decay = int(number_of_epochs/2)\n", - "\n", - "if Use_pretrained_model :\n", - " for f in os.listdir(pretrained_model_path):\n", - " if (f.startswith(\"latest_net_\")): \n", - " shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n", - "\n", - "\n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n", - "\n", - "print(\"Data ready for training\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0Dfn8ZsEMv5d" - }, - "source": [ - "## **4.2. Start Training**\n", - "---\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session.\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "iwNmp1PUzRDQ", - "scrolled": true - }, - "outputs": [], - "source": [ - "\n", - "#@markdown ##Start training\n", - "\n", - "start = time.time()\n", - "\n", - "os.chdir(\"/content\")\n", - "\n", - "#--------------------------------- Command line inputs to change CycleGAN paramaters------------\n", - "\n", - " # basic parameters\n", - " #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n", - " #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n", - " #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n", - " #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n", - " \n", - " # model parameters\n", - " #('--model', type=str, default='cycle_gan', help='chooses which model to use. 
[cycle_gan | pix2pix | test | colorization]')\n", - " #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n", - " #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n", - " #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n", - " #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n", - " #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n", - " #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n", - " #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n", - " #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n", - " #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n", - " #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n", - " #('--no_dropout', action='store_true', help='no dropout for the generator')\n", - " \n", - " # dataset parameters\n", - " #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n", - " #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n", - " #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n", - " #('--num_threads', default=4, type=int, help='# threads for loading data')\n", - " #('--batch_size', type=int, default=1, help='input batch size')\n", - " #('--load_size', type=int, default=286, help='scale images to this size')\n", - " #('--crop_size', type=int, default=256, help='then crop to this size')\n", - " #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n", - " #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n", - " #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n", - " #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n", - " \n", - " # additional parameters\n", - " #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n", - " #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n", - " #('--verbose', action='store_true', help='if specified, print more debugging information')\n", - " #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n", - " \n", - " # visdom and HTML visualization parameters\n", - " #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n", - " #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n", - " #('--display_id', type=int, default=1, help='window id of the web display')\n", - " #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n", - " #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n", - " #('--display_port', type=int, default=8097, help='visdom port of the web display')\n", - " #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n", - " #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n", - " #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n", - " \n", - " # network saving and loading parameters\n", - " #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n", - " #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n", - " #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n", - " #('--continue_train', action='store_true', help='continue training: load the latest model')\n", - " #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n", - " #('--phase', type=str, default='train', help='train, val, test, etc')\n", - " \n", - " # training parameters\n", - " #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n", - " #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n", - " #('--beta1', type=float, default=0.5, help='momentum term of adam')\n", - " #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n", - " #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n", - " #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n", - " #('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n", - " #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n", - "\n", - "#---------------------------------------------------------\n", - "\n", - "#----- Start the training ------------------------------------\n", - "if not Use_pretrained_model:\n", - " if Use_Data_augmentation:\n", - " !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n", - " if not Use_Data_augmentation:\n", - " !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --no_flip\n", - "\n", - "if Use_pretrained_model:\n", - " if Use_Data_augmentation:\n", - " !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n", - " \n", - " if not Use_Data_augmentation:\n", - " !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train --no_flip\n", - "\n", - "#---------------------------------------------------------\n", - "\n", - "print(\"Training, done.\")\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "# Save training summary as pdf\n", - "\n", - "pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_0Hynw3-xHp1" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n", - "Unfortunately loss functions curve are not very informative for GAN network. Therefore we perform the QC here using a test dataset.\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1Wext8woxt_F" - }, - "source": [ - "## **5.1. 
Choose the model you want to assess**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "eAJzMwPA6tlH" - }, - "outputs": [], - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "QC_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "QC_model_name = os.path.basename(QC_model_folder)\n", - "QC_model_path = os.path.dirname(QC_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", - "\n", - "full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n", - "if os.path.exists(full_QC_model_path):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1CFbjvTpx5C3" - }, - "source": [ - "## **5.2. Identify the best checkpoint to use to make predictions**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "q8tCfAadx96X" - }, - "source": [ - " CycleGAN save model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n", - "\n", - "This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truths images. Metric used include:\n", - "\n", - "**1. The SSIM (structural similarity) map** \n", - "\n", - "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n", - "\n", - "**mSSIM** is the SSIM value calculated across the entire window of both images.\n", - "\n", - "**The output below shows the SSIM maps with the mSSIM**\n", - "\n", - "**2. The RSE (Root Squared Error) map** \n", - "\n", - "This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n", - "\n", - "\n", - "**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. 
Good agreement yields low NRMSE scores.\n", - "\n", - "**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n", - "\n", - "**The output below shows the RSE maps with the NRMSE and PSNR values.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "q2T4t8NNyDZ6" - }, - "outputs": [], - "source": [ - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_QC_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\n", - "\n", - "# average function\n", - "def Average(lst): \n", - " return sum(lst) / len(lst) \n", - "\n", - "\n", - "# Create a quality control folder\n", - "\n", - "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n", - " shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n", - "\n", - "os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n", - "\n", - "# List images in Source_QC_folder\n", - "# This will find the image dimension of a randomly choosen image in Source_QC_folder \n", - "random_choice = random.choice(os.listdir(Source_QC_folder))\n", - "x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "Image_min_dim = min(Image_Y, Image_X)\n", - "\n", - "\n", - "# Here we need to move the data to be analysed so that cycleGAN can find them\n", - "\n", - "Saving_path_QC= \"/content/\"+QC_model_name\n", - "\n", - "if os.path.exists(Saving_path_QC):\n", - " shutil.rmtree(Saving_path_QC)\n", - "os.makedirs(Saving_path_QC)\n", - "\n", - "Saving_path_QC_folder = Saving_path_QC+\"_images\"\n", - "\n", - "if os.path.exists(Saving_path_QC_folder):\n", - " shutil.rmtree(Saving_path_QC_folder)\n", - "os.makedirs(Saving_path_QC_folder)\n", - "\n", - "\n", - "#Here we copy and rename the all the checkpoint to be analysed\n", - "\n", - "for f in os.listdir(full_QC_model_path):\n", - " shortname = f[:-6]\n", - " shortname = shortname + \".pth\"\n", - " if f.endswith(\"net_G_A.pth\"):\n", - " shutil.copyfile(full_QC_model_path+f, Saving_path_QC+\"/\"+shortname)\n", - "\n", - "\n", - "for files in os.listdir(Source_QC_folder):\n", - " shutil.copyfile(Source_QC_folder+\"/\"+files, Saving_path_QC_folder+\"/\"+files)\n", - " \n", - "\n", - "# This will find the image dimension of a randomly chosen image in Source_QC_folder \n", - "random_choice = random.choice(os.listdir(Source_QC_folder))\n", - "x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "Image_min_dim = int(min(Image_Y, Image_X))\n", - "\n", - "Nb_Checkpoint = len(os.listdir(Saving_path_QC))\n", - "\n", - "print(Nb_Checkpoint)\n", - "\n", - "\n", - "\n", - "## Initiate list\n", - "\n", - "Checkpoint_list = []\n", - "Average_ssim_score_list = []\n", - "\n", - "\n", - "for j in range(1, len(os.listdir(Saving_path_QC))+1):\n", - " checkpoints = j*5\n", - "\n", - " if checkpoints == Nb_Checkpoint*5:\n", - " checkpoints = \"latest\"\n", - "\n", - "\n", - " print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n", 
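For clarity, the sketch below shows how the metrics described above (mSSIM, the per-pixel RSE map, the NRMSE value as it is computed in this notebook, and PSNR) can be obtained for a single normalised ground-truth/prediction pair with scikit-image. The arrays are hypothetical stand-ins; the quality-control code in this section applies the same calls to every image and every checkpoint.

```python
import numpy as np
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

# Hypothetical stand-ins for a normalised ground-truth image and a prediction.
gt = np.random.rand(256, 256).astype(np.float32)
pred = np.random.rand(256, 256).astype(np.float32)

# mSSIM and the full SSIM map (same settings as the grayscale QC below).
mssim, ssim_map = structural_similarity(gt, pred, data_range=1.0, full=True,
                                        gaussian_weights=True, sigma=1.5,
                                        use_sample_covariance=False)

rse_map = np.sqrt(np.square(gt - pred))   # root squared error per pixel
nrmse = np.sqrt(np.mean(rse_map))         # NRMSE as computed in this notebook
psnr = peak_signal_noise_ratio(gt, pred, data_range=1.0)

print(f"mSSIM: {mssim:.3f}, NRMSE: {nrmse:.3f}, PSNR: {psnr:.2f} dB")
```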
- "\n", - " Checkpoint_list.append(checkpoints)\n", - "\n", - "\n", - " # Create a quality control/Prediction Folder\n", - "\n", - " QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n", - "\n", - " if os.path.exists(QC_prediction_results):\n", - " shutil.rmtree(QC_prediction_results)\n", - "\n", - " os.makedirs(QC_prediction_results)\n", - "\n", - "\n", - "\n", - "#---------------------------- Predictions are performed here ----------------------\n", - "\n", - " os.chdir(\"/content\")\n", - "\n", - " !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_QC_folder\" --name \"$QC_model_name\" --model test --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$QC_prediction_results\" --checkpoints_dir \"/content/\"\n", - "\n", - "#-----------------------------------------------------------------------------------\n", - "\n", - "#Here we need to move the data again and remove all the unnecessary folders\n", - "\n", - " Checkpoint_name = \"test_\"+str(checkpoints)\n", - "\n", - " QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n", - "\n", - " QC_results_images_files = os.listdir(QC_results_images)\n", - "\n", - " for f in QC_results_images_files: \n", - " shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n", - "\n", - " os.chdir(\"/content\") \n", - "\n", - " #Here we clean up the extra files\n", - " shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n", - "\n", - "\n", - "#-------------------------------- QC for RGB ------------------------------------\n", - " if Image_type == \"RGB\":\n", - "# List images in Source_QC_folder\n", - "# This will find the image dimension of a randomly choosen image in Source_QC_folder \n", - " random_choice = random.choice(os.listdir(Source_QC_folder))\n", - " x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n", - "\n", - " def ssim(img1, img2):\n", - " return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - " with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. 
GT mSSIM\"])\n", - " \n", - " \n", - " # Initiate list\n", - " ssim_score_list = [] \n", - "\n", - "\n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - "\n", - " for i in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", - " print('Running QC on: '+i)\n", - "\n", - " shortname_no_PNG = i[:-4]\n", - " \n", - " # -------------------------------- Target test data (Ground truth) --------------------------------\n", - " test_GT = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n", - "\n", - " # -------------------------------- Source test data --------------------------------\n", - " test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n", - " \n", - " \n", - " # -------------------------------- Prediction --------------------------------\n", - " \n", - " test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n", - " \n", - " #--------------------------- Here we normalise using histograms matching--------------------------------\n", - " test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n", - " test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n", - " \n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n", - "\n", - " ssim_score_list.append(index_SSIM_GTvsPrediction)\n", - "\n", - " #Save ssim_maps\n", - " img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n", - " img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n", - " \n", - " \n", - " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\n", - "\n", - " #Here we calculate the ssim average for each image in each checkpoints\n", - "\n", - " Average_SSIM_checkpoint = Average(ssim_score_list)\n", - " Average_ssim_score_list.append(Average_SSIM_checkpoint)\n", - "\n", - "\n", - "\n", - "\n", - "#------------------------------------------- QC for Grayscale ----------------------------------------------\n", - "\n", - " if Image_type == \"Grayscale\":\n", - " def ssim(img1, img2):\n", - " return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n", - "\n", - "\n", - " def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - "\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - " def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \n", - " if dtype is not None:\n", - " x = 
x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - " def norm_minmse(gt, x, normalize_gt=True):\n", - " \n", - " if normalize_gt:\n", - " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", - " x = x.astype(np.float32, copy=False) - np.mean(x)\n", - " #x = x - np.mean(x)\n", - " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", - " #gt = gt - np.mean(gt)\n", - " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", - " return gt, scale * x\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - " with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n", - "\n", - " \n", - " \n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - "\n", - " for i in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", - " print('Running QC on: '+i)\n", - "\n", - " ssim_score_list = []\n", - " shortname_no_PNG = i[:-4]\n", - " # -------------------------------- Target test data (Ground truth) --------------------------------\n", - " test_GT_raw = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n", - " \n", - " test_GT = test_GT_raw[:,:,2]\n", - "\n", - " # -------------------------------- Source test data --------------------------------\n", - " test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n", - " \n", - " test_source = test_source_raw[:,:,2]\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n", - " test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n", - "\n", - " # -------------------------------- Prediction --------------------------------\n", - " test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n", - " \n", - " test_prediction = test_prediction_raw[:,:,2]\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n", - " test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n", - "\n", - "\n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n", - "\n", - " ssim_score_list.append(index_SSIM_GTvsPrediction)\n", - "\n", - " #Save ssim_maps\n", - " \n", 
- " img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n", - " img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n", - " \n", - " # Calculate the Root Squared Error (RSE) maps\n", - " img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", - " img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n", - "\n", - " # Save SE maps\n", - " img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n", - " img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n", - "\n", - "\n", - " # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n", - "\n", - " # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n", - " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", - " NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n", - " \n", - " # We can also measure the peak signal to noise ratio between the images\n", - " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", - " PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n", - "\n", - " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n", - "\n", - " #Here we calculate the ssim average for each image in each checkpoints\n", - "\n", - " Average_SSIM_checkpoint = Average(ssim_score_list)\n", - " Average_ssim_score_list.append(Average_SSIM_checkpoint)\n", - "\n", - "\n", - "# All data is now processed saved\n", - " \n", - "\n", - "# -------------------------------- Display --------------------------------\n", - "\n", - "# Display the IoV vs Threshold plot\n", - "plt.figure(figsize=(20,5))\n", - "plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n", - "plt.title('Checkpoints vs. 
SSIM')\n", - "plt.ylabel('SSIM')\n", - "plt.xlabel('Checkpoints')\n", - "plt.legend()\n", - "plt.savefig(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n", - "plt.show()\n", - "\n", - "\n", - "\n", - "# -------------------------------- Display RGB --------------------------------\n", - "\n", - "from ipywidgets import interact\n", - "import ipywidgets as widgets\n", - "\n", - "\n", - "if Image_type == \"RGB\":\n", - " random_choice_shortname_no_PNG = shortname_no_PNG\n", - "\n", - " @interact\n", - " def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n", - "\n", - " random_choice_shortname_no_PNG = file[:-4]\n", - "\n", - " df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n", - " df2 = df1.set_index(\"image #\", drop = False)\n", - " index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n", - " index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n", - "\n", - "#Setting up colours\n", - " \n", - " cmap = None\n", - "\n", - " plt.figure(figsize=(10,10))\n", - "\n", - "# Target (Ground-truth)\n", - " plt.subplot(3,3,1)\n", - " plt.axis('off')\n", - " img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n", - " plt.imshow(img_GT, cmap = cmap)\n", - " plt.title('Target',fontsize=15)\n", - "\n", - "# Source\n", - " plt.subplot(3,3,2)\n", - " plt.axis('off')\n", - " img_Source = imageio.imread(os.path.join(Source_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n", - " plt.imshow(img_Source, cmap = cmap)\n", - " plt.title('Source',fontsize=15)\n", - "\n", - "#Prediction\n", - " plt.subplot(3,3,3)\n", - " plt.axis('off')\n", - "\n", - " img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n", - "\n", - " plt.imshow(img_Prediction, cmap = cmap)\n", - " plt.title('Prediction',fontsize=15)\n", - "\n", - "\n", - "#SSIM between GT and Source\n", - " plt.subplot(3,3,5)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - "#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. 
Source',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - " plt.subplot(3,3,6)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "\n", - " img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - "#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", - " plt.savefig(full_QC_model_path+'Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "# -------------------------------- Display Grayscale --------------------------------\n", - "\n", - "if Image_type == \"Grayscale\":\n", - " random_choice_shortname_no_PNG = shortname_no_PNG\n", - "\n", - " @interact\n", - " def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n", - "\n", - " random_choice_shortname_no_PNG = file[:-4]\n", - "\n", - " df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n", - " df2 = df1.set_index(\"image #\", drop = False)\n", - " index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n", - " index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n", - "\n", - " NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n", - " NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n", - " PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n", - " PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. 
GT PSNR\"]\n", - " \n", - "\n", - " plt.figure(figsize=(15,15))\n", - "\n", - " cmap = None\n", - " \n", - " # Target (Ground-truth)\n", - " plt.subplot(3,3,1)\n", - " plt.axis('off')\n", - " img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=True, pilmode=\"RGB\")\n", - "\n", - " plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99), cmap = 'gray')\n", - " plt.title('Target',fontsize=15)\n", - "\n", - "# Source\n", - " plt.subplot(3,3,2)\n", - " plt.axis('off')\n", - " img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real.png\"))\n", - " plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n", - " plt.title('Source',fontsize=15)\n", - "\n", - "#Prediction\n", - " plt.subplot(3,3,3)\n", - " plt.axis('off')\n", - " img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n", - " plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n", - " plt.title('Prediction',fontsize=15)\n", - "\n", - "#Setting up colours\n", - " cmap = plt.cm.CMRmap\n", - "\n", - "#SSIM between GT and Source\n", - " plt.subplot(3,3,5)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - " img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n", - " imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - "\n", - " \n", - " plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. Source',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - " plt.subplot(3,3,6)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - " \n", - " \n", - " img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - " img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n", - " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - "\n", - " \n", - " plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. 
Prediction',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#Root Squared Error between GT and Source\n", - " plt.subplot(3,3,8)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - " img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n", - " \n", - "\n", - " imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n", - " plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Source',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n", - "#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n", - " plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#Root Squared Error between GT and Prediction\n", - " plt.subplot(3,3,9)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n", - "\n", - " imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n", - " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "\n", - "#Make a pdf summary of the QC results\n", - "\n", - "qc_pdf_export()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-tJeeJjLnRkP" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d8wuQGjoq6eN" - }, - "source": [ - "## **6.1. Generate prediction(s) from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.2) can now be used to process images. 
If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n", - "\n", - "**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output images.\n", - "\n", - "**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\"." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "y2TD5p7MZrEb" - }, - "outputs": [], - "source": [ - "#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n", - "\n", - "import glob\n", - "import os.path\n", - "\n", - "\n", - "latest = \"latest\"\n", - "\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Result_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "# model name and path\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ###What model checkpoint would you like to use?\n", - "\n", - "checkpoint = latest#@param {type:\"raw\"}\n", - "\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "#here we check if we use the newly trained network or not\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "#here we check if the model exists\n", - "full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n", - "\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "\n", - "# Here we check that checkpoint exist, if not the closest one will be chosen \n", - "\n", - "Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G_A.pth')))\n", - "print(Nb_Checkpoint)\n", - "\n", - "\n", - "if not checkpoint == \"latest\":\n", - "\n", - " if checkpoint < 10:\n", - " checkpoint = 5\n", - "\n", - " if not checkpoint % 5 == 0:\n", - " checkpoint = ((int(checkpoint / 5)-1) * 5)\n", - " print (bcolors.WARNING + \" Your chosen checkpoints is not divisible by 5; therefore the checkpoints chosen is now:\",checkpoints)\n", - " \n", - " if checkpoint > Nb_Checkpoint*5:\n", - " checkpoint = \"latest\"\n", - "\n", - " if checkpoint == Nb_Checkpoint*5:\n", - " checkpoint = \"latest\"\n", - "\n", - "\n", - "\n", - "\n", - "# Here we need to move the data to be analysed so that cycleGAN can find them\n", - "\n", - "Saving_path_prediction= \"/content/\"+Prediction_model_name\n", - "\n", - "if os.path.exists(Saving_path_prediction):\n", - " shutil.rmtree(Saving_path_prediction)\n", - "os.makedirs(Saving_path_prediction)\n", - "\n", - "Saving_path_Data_folder = Saving_path_prediction+\"/testA\"\n", - "\n", - "if os.path.exists(Saving_path_Data_folder):\n", - " shutil.rmtree(Saving_path_Data_folder)\n", - "os.makedirs(Saving_path_Data_folder)\n", - "\n", - "for files in os.listdir(Data_folder):\n", - " shutil.copyfile(Data_folder+\"/\"+files, Saving_path_Data_folder+\"/\"+files)\n", - "\n", - "\n", - "Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n", - "\n", - "\n", - "\n", - "#Here we copy and rename the checkpoint to be used\n", - "\n", - "shutil.copyfile(full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G_A.pth\", full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G.pth\")\n", - "\n", - "\n", - "# This will find the image dimension of a randomly choosen image in Data_folder \n", - "random_choice = random.choice(os.listdir(Data_folder))\n", - "x = imageio.imread(Data_folder+\"/\"+random_choice)\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "Image_min_dim = min(Image_Y, Image_X)\n", - "\n", - "print(Image_min_dim)\n", - "\n", - "\n", - "\n", - "#-------------------------------- Perform predictions -----------------------------\n", - "\n", - "#-------------------------------- Options that can be used to perform predictions -----------------------------\n", - "\n", - "# basic parameters\n", - " #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n", - " #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n", - " #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n", - " #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n", - "\n", - "# model parameters\n", - " #('--model', type=str, default='cycle_gan', help='chooses which model to use. 
[cycle_gan | pix2pix | test | colorization]')\n", - " #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n", - " #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n", - " #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n", - " #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n", - " #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n", - " #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n", - " #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n", - " #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n", - " #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n", - " #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n", - " #('--no_dropout', action='store_true', help='no dropout for the generator')\n", - " \n", - "# dataset parameters\n", - " #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n", - " #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n", - " #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n", - " #('--num_threads', default=4, type=int, help='# threads for loading data')\n", - " #('--batch_size', type=int, default=1, help='input batch size')\n", - " #('--load_size', type=int, default=286, help='scale images to this size')\n", - " #('--crop_size', type=int, default=256, help='then crop to this size')\n", - " #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n", - " #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n", - " #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n", - " #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n", - " \n", - "# additional parameters\n", - " #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n", - " #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n", - " #('--verbose', action='store_true', help='if specified, print more debugging information')\n", - " #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n", - " \n", - "\n", - " #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n", - " #('--results_dir', type=str, default='./results/', help='saves results here.')\n", - " #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n", - " #('--phase', type=str, default='test', help='train, val, test, etc')\n", - "\n", - "# Dropout and Batchnorm has different behavioir during training and test.\n", - " #('--eval', action='store_true', help='use eval mode during test time.')\n", - " #('--num_test', type=int, default=50, help='how many test images to run')\n", - " # rewrite devalue values\n", - " \n", - "# To avoid cropping, the load_size should be the same as crop_size\n", - " #parser.set_defaults(load_size=parser.get_default('crop_size'))\n", - "\n", - "#------------------------------------------------------------------------\n", - "\n", - "\n", - "#---------------------------- Predictions are performed here ----------------------\n", - "\n", - "os.chdir(\"/content\")\n", - "\n", - "!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_Data_folder\" --name \"$Prediction_model_name\" --model test --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n", - "\n", - "#-----------------------------------------------------------------------------------\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SXqS_EhByhQ7" - }, - "source": [ - "## **6.2. Inspect the predicted output**\n", - "---\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "64emoATwylxM" - }, - "outputs": [], - "source": [ - "# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n", - "import os\n", - "# This will display a randomly chosen dataset input and predicted output\n", - "random_choice = random.choice(os.listdir(Data_folder))\n", - "\n", - "\n", - "random_choice_no_extension = os.path.splitext(random_choice)\n", - "\n", - "\n", - "x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real.png\")\n", - "\n", - "\n", - "y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake.png\")\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x, interpolation='nearest')\n", - "plt.title('Input')\n", - "plt.axis('off');\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y, interpolation='nearest')\n", - "plt.title('Prediction')\n", - "plt.axis('off');\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.3. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) 
if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UvSlTaH14s3t" - }, - "source": [ - "\n", - "#**Thank you for using CycleGAN!**" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "machine_shape": "hm", - "name": "CycleGAN_ZeroCostDL4Mic.ipynb", - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"CycleGAN_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611059046709},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **CycleGAN**\n","\n","---\n","\n","CycleGAN is a method that can capture the characteristics of one image domain and learn how these characteristics can be translated into another image domain, all in the absence of any paired training examples. It was first published by [Zhu *et al.* in 2017](https://arxiv.org/abs/1703.10593). Unlike pix2pix, the image transformation performed does not require paired images for training (unsupervised learning) and is made possible here by using a set of two Generative Adversarial Networks (GANs) that learn to transform images both from the first domain to the second and vice-versa.\n","\n"," **This particular notebook enables unpaired image-to-image translation. If your dataset is paired, you should also consider using the pix2pix notebook.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks** from Zhu *et al.* published in arXiv in 2018 (https://arxiv.org/abs/1703.10593)\n","\n","The source code of the CycleGAN PyTorch implementation can be found in: /~https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jqvkQQkcuMmM"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","id":"vCihhAzluRvI"},"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n","#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\n","#BSD License\n","\n","#For pix2pix software\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\n","#BSD License\n","\n","#For dcgan.torch software\n","\n","#Copyright (c) 2015, Facebook, Inc. 
All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["#**0. Before getting started**\n","---\n"," To train CycleGAN, **you only need two folders containing PNG images**. The images do not need to be paired.\n","\n","While you do not need paired images to train CycleGAN, if possible, **we strongly recommend that you generate a paired dataset. This means that the same image needs to be acquired in the two conditions. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","\n"," Please note that you currently can **only use .png files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset (non-matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset (matching images)**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. 
Install CycleGAN and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'CycleGAN'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install CycleGAN and dependencies\n","\n","#------- Code from the cycleGAN demo notebook starts here -------\n","\n","#Here, we install libraries which are not already included in Colab.\n","!git clone /~https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","!pip install fpdf\n","\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","\n","from skimage.util import img_as_int\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","import torch\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the 
latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell\n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," # print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and an least-square GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by default'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>initial_learning_rate</td><td>{3}</td></tr>
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_cycleGAN.png').shape\n"," pdf.image('/content/TrainingDataExample_cycleGAN.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'cycleGAN'\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png').shape\n"," pdf.image(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," if Image_type == 'RGB':\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/5), h = round(exp_size[0]/5))\n"," if Image_type == 'Grayscale':\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," for checkpoint in os.listdir(full_QC_model_path+'Quality Control'):\n"," if os.path.isdir(os.path.join(full_QC_model_path,'Quality Control',checkpoint)):\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(70, 5, txt = 'Metrics for checkpoint: '+ str(checkpoint), align='L', ln=1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'Quality Control/'+str(checkpoint)+'/QC_metrics_'+QC_model_name+str(checkpoint)+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}
{0}{1}{2}
\"\"\"\n"," pdf.write_html(html)\n"," pdf.ln(2)\n"," else:\n"," continue\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- cycleGAN: Zhu, Jun-Yan, et al. \"Unpaired image-to-image translation using cycle-consistent adversarial networks.\" Proceedings of the IEEE international conference on computer vision. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n","\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** CycleGAN divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 4. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0002**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","OutputFile = Training_target+\"/*.png\"\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 2#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," initial_learning_rate = 0.0002\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","\n","\n","#To use Cyclegan we need to organise the data in a way the model can understand\n","\n","Saving_path= \"/content/\"+model_name\n","#Saving_path= model_path+\"/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","TrainA_Folder = Saving_path+\"/trainA\"\n","if os.path.exists(TrainA_Folder):\n"," shutil.rmtree(TrainA_Folder)\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/trainB\"\n","if os.path.exists(TrainB_Folder):\n"," shutil.rmtree(TrainB_Folder)\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imageio.imread(Training_source+\"/\"+random_choice)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","random_choice_2 = random.choice(os.listdir(Training_target))\n","y = 
imageio.imread(Training_target+\"/\"+random_choice_2)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_cycleGAN.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by flipping the patches. \n","\n"," By default data augmentation is enabled."]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a CycleGAN model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path_A = os.path.join(pretrained_model_path, \"latest_net_G_A.pth\")\n"," h5_file_path_B = os.path.join(pretrained_model_path, \"latest_net_G_B.pth\")\n","\n","# --------------------- Check the model exist ------------------------\n","\n"," if not os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path_A) and os.path.exists(h5_file_path_B):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from 3. to prepare the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Prepare the data for training\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we move the files to trainA and train B ---------\n","\n","\n","for f in os.listdir(Training_source):\n"," shutil.copyfile(Training_source+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","for files in os.listdir(Training_target):\n"," shutil.copyfile(Training_target+\"/\"+files, TrainB_Folder+\"/\"+files)\n","\n","#---------------------------------------------------------------------\n","\n","# CycleGAN use number of EPOCH withouth lr decay and number of EPOCH with lr decay\n","\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","print(\"Data ready for training\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["\n","#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change CycleGAN paramaters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --no_flip\n","\n","if Use_pretrained_model:\n"," if Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n"," \n"," if not Use_Data_augmentation:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$Saving_path\" --input_nc 3 --name $model_name --model cycle_gan --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train --no_flip\n","\n","#---------------------------------------------------------\n","\n","print(\"Training, done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","# Save training summary as pdf\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","Unfortunately loss functions curve are not very informative for GAN network. Therefore we perform the QC here using a test dataset.\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"1Wext8woxt_F"},"source":["## **5.1. 
Choose the model you want to assess**"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1CFbjvTpx5C3"},"source":["## **5.2. Identify the best checkpoint to use to make predictions**"]},{"cell_type":"markdown","metadata":{"id":"q8tCfAadx96X"},"source":[" CycleGAN saves model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n","\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truth images. The metrics used include:\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire image.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
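These metrics can also be computed directly with scikit-image's `structural_similarity` and `peak_signal_noise_ratio`, which the quality-control cell below relies on. Here is a minimal, self-contained sketch using synthetic arrays rather than the notebook's data:

```python
# Minimal sketch of the QC metrics described above (mSSIM, SSIM map, RSE, NRMSE, PSNR).
# Synthetic single-channel images scaled to [0,1] are used here instead of the notebook's data.
import numpy as np
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

rng = np.random.default_rng(0)
ground_truth = rng.random((256, 256)).astype(np.float32)
prediction = np.clip(ground_truth + 0.05 * rng.standard_normal((256, 256)).astype(np.float32), 0, 1)

# mSSIM and the per-pixel SSIM map (11x11 window with Gaussian weighting, sigma = 1.5)
mSSIM, ssim_map = structural_similarity(ground_truth, prediction, data_range=1.0, full=True,
                                        gaussian_weights=True, sigma=1.5, use_sample_covariance=False)

# RSE map, NRMSE (as computed in this notebook: square root of the mean of the RSE map) and PSNR
rse_map = np.sqrt(np.square(ground_truth - prediction))
nrmse = np.sqrt(np.mean(rse_map))
psnr_value = peak_signal_noise_ratio(ground_truth, prediction, data_range=1.0)

print("mSSIM:", round(float(mSSIM), 3), "NRMSE:", round(float(nrmse), 3), "PSNR (dB):", round(float(psnr_value), 2))
```
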
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"q2T4t8NNyDZ6"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\n","\n","# average function\n","def Average(lst): \n"," return sum(lst) / len(lst) \n","\n","\n","# Create a quality control folder\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_QC= \"/content/\"+QC_model_name\n","\n","if os.path.exists(Saving_path_QC):\n"," shutil.rmtree(Saving_path_QC)\n","os.makedirs(Saving_path_QC)\n","\n","Saving_path_QC_folder = Saving_path_QC+\"_images\"\n","\n","if os.path.exists(Saving_path_QC_folder):\n"," shutil.rmtree(Saving_path_QC_folder)\n","os.makedirs(Saving_path_QC_folder)\n","\n","\n","#Here we copy and rename the all the checkpoint to be analysed\n","\n","for f in os.listdir(full_QC_model_path):\n"," shortname = f[:-6]\n"," shortname = shortname + \".pth\"\n"," if f.endswith(\"net_G_A.pth\"):\n"," shutil.copyfile(full_QC_model_path+f, Saving_path_QC+\"/\"+shortname)\n","\n","\n","for files in os.listdir(Source_QC_folder):\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, Saving_path_QC_folder+\"/\"+files)\n"," \n","\n","# This will find the image dimension of a randomly chosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = int(min(Image_Y, Image_X))\n","\n","Nb_Checkpoint = len(os.listdir(Saving_path_QC))\n","\n","print(Nb_Checkpoint)\n","\n","\n","\n","## Initiate list\n","\n","Checkpoint_list = []\n","Average_ssim_score_list = []\n","\n","\n","for j in range(1, len(os.listdir(Saving_path_QC))+1):\n"," checkpoints = j*5\n","\n"," if checkpoints == Nb_Checkpoint*5:\n"," checkpoints = \"latest\"\n","\n","\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n","\n"," Checkpoint_list.append(checkpoints)\n","\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n"," os.chdir(\"/content\")\n","\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_QC_folder\" --name \"$QC_model_name\" --model test --epoch 
$checkpoints --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$QC_prediction_results\" --checkpoints_dir \"/content/\"\n","\n","#-----------------------------------------------------------------------------------\n","\n","#Here we need to move the data again and remove all the unnecessary folders\n","\n"," Checkpoint_name = \"test_\"+str(checkpoints)\n","\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n"," QC_results_images_files = os.listdir(QC_results_images)\n","\n"," for f in QC_results_images_files: \n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n","\n"," os.chdir(\"/content\") \n","\n"," #Here we clean up the extra files\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n","\n","\n","#-------------------------------- QC for RGB ------------------------------------\n"," if Image_type == \"RGB\":\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n"," random_choice = random.choice(os.listdir(Source_QC_folder))\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\"])\n"," \n"," \n"," # Initiate list\n"," ssim_score_list = [] \n","\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," \n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," \n"," # -------------------------------- Prediction --------------------------------\n"," \n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," #--------------------------- Here we normalise using histograms matching--------------------------------\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n"," \n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_8bit = 
(img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," \n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","\n","\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\n","\n"," if Image_type == \"Grayscale\":\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n","\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n","\n"," \n"," \n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," ssim_score_list = []\n"," shortname_no_PNG = i[:-4]\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT_raw = imageio.imread(os.path.join(Target_QC_folder, i), as_gray=False, pilmode=\"RGB\")\n"," \n"," test_GT = test_GT_raw[:,:,2]\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real.png\"))\n"," \n"," test_source = test_source_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake.png\"))\n"," \n"," test_prediction = test_prediction_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," \n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can 
also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n","\n","# All data is now processed saved\n"," \n","\n","# -------------------------------- Display --------------------------------\n","\n","# Display the IoV vs Threshold plot\n","plt.figure(figsize=(20,5))\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n","plt.title('Checkpoints vs. SSIM')\n","plt.ylabel('SSIM')\n","plt.xlabel('Checkpoints')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","\n","# -------------------------------- Display RGB --------------------------------\n","\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","if Image_type == \"RGB\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. 
GT mSSIM\"]\n","\n","#Setting up colours\n"," \n"," cmap = None\n","\n"," plt.figure(figsize=(10,10))\n","\n","# Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_GT, cmap = cmap)\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(Source_QC_folder, file), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_Source, cmap = cmap)\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n","\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n","\n"," plt.imshow(img_Prediction, cmap = cmap)\n"," plt.title('Prediction',fontsize=15)\n","\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","# -------------------------------- Display Grayscale --------------------------------\n","\n","if Image_type == \"Grayscale\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\n"," \n","\n"," plt.figure(figsize=(15,15))\n","\n"," cmap = None\n"," \n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(Target_QC_folder, file), as_gray=True, pilmode=\"RGB\")\n","\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99), cmap = 'gray')\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real.png\"))\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake.png\"))\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," \n"," \n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","\n"," \n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n"," \n","\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n","\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\"."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","import glob\n","import os.path\n","\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","#here we check if the model exists\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# Here we check that the checkpoint exists; if not, the closest one will be chosen \n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G_A.pth')))\n","print(Nb_Checkpoint)\n","\n","\n","if not checkpoint == \"latest\":\n","\n"," if checkpoint < 10:\n"," checkpoint = 5\n","\n"," if not checkpoint % 5 == 0:\n"," checkpoint = ((int(checkpoint / 5)-1) * 5)\n"," print(bcolors.WARNING + \" Your chosen checkpoint is not divisible by 5; therefore the checkpoint chosen is now:\", checkpoint)\n"," \n"," if checkpoint > Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n"," if checkpoint == Nb_Checkpoint*5:\n"," checkpoint = \"latest\"\n","\n","\n","\n","\n","# Here we need to move the data to be analysed so that cycleGAN can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n"," shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","Saving_path_Data_folder = Saving_path_prediction+\"/testA\"\n","\n","if os.path.exists(Saving_path_Data_folder):\n"," shutil.rmtree(Saving_path_Data_folder)\n","os.makedirs(Saving_path_Data_folder)\n","\n","for files in os.listdir(Data_folder):\n"," shutil.copyfile(Data_folder+\"/\"+files, Saving_path_Data_folder+\"/\"+files)\n","\n","\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","\n","\n","#Here we copy and rename the checkpoint to be used\n","\n","shutil.copyfile(full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G_A.pth\", full_Prediction_model_path+\"/\"+str(checkpoint)+\"_net_G.pth\")\n","\n","\n","# This will find the image dimension of a randomly chosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","print(Image_min_dim)\n","\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
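Since the generator checkpoints are written every five epochs as `<epoch>_net_G_A.pth` (with the most recent one also available as `latest_net_G_A.pth`), the requested checkpoint has to correspond to an existing file. The snippet below is only an illustrative sketch of that idea, with a placeholder folder path and epoch; it is not the exact rounding rule implemented in this cell:

```python
# Illustrative helper (not the exact rule used by this cell): snap a requested epoch to an
# existing CycleGAN checkpoint. Checkpoints are assumed to be saved every 5 epochs as
# "<epoch>_net_G_A.pth", with the most recent copy also saved as "latest_net_G_A.pth".
# `model_folder` and the requested epoch below are placeholders.
import glob
import os

def pick_checkpoint(model_folder, requested):
    if requested == "latest":
        return "latest"
    saved = sorted(int(os.path.basename(f).split("_")[0])
                   for f in glob.glob(os.path.join(model_folder, "*_net_G_A.pth"))
                   if os.path.basename(f).split("_")[0].isdigit())
    if not saved:
        return "latest"
    lower = [epoch for epoch in saved if epoch <= requested]
    return lower[-1] if lower else saved[0]   # closest saved epoch not larger than the request

print(pick_checkpoint("/content/my_model", 23))   # would return 20 if epochs 5, 10, 15, 20, ... exist
```
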
n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm has different behavioir during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite devalue values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$Saving_path_Data_folder\" --name \"$Prediction_model_name\" --model test --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $Image_min_dim --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SXqS_EhByhQ7"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"64emoATwylxM"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import os\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","\n","\n","random_choice_no_extension = os.path.splitext(random_choice)\n","\n","\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real.png\")\n","\n","\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake.png\")\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Prediction')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"pE8vQZ7RWY_L"},"source":["# **7. 
Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version now includes an automatic restart that sets the h5py library to v2.10.\n","* Sections 1 and 2 are now swapped for better export of *requirements.txt*.\n","\n","* This version also now includes a built-in version check and the version log that you're reading now.\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using CycleGAN!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb index d28f89e1..d0708a05 100644 --- a/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Deep-STORM_2D_ZeroCostDL4Mic.ipynb @@ -1,3110 +1 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "FpCtYevLHfl4" - }, - "source": [ - "# **Deep-STORM (2D)**\n", - "\n", - "---\n", - "\n", - "Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. The network is trained using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. The network output a super-resolution image with increased pixel density (typically upsampling factor of 8 in each dimension).\n", - "\n", - "Deep-STORM has **two key advantages**:\n", - "- SMLM reconstruction at high density of emitters\n", - "- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n", - "\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is based on the following paper: \n", - "\n", - "**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. 
Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n", - "\n", - "And source code found in: /~https://github.com/EliasNehme/Deep-STORM\n", - "\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wyzTn3IcHq6Y" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use our notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bEy4EBXHHyAX" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - " Deep-STORM is able to train on simulated dataset of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. 
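For intuition, the toy sketch below illustrates the kind of computation such a simulation involves: random emitter positions placed on a pixel grid, blurred by a Gaussian PSF and corrupted with background and shot noise. All parameter values are arbitrary placeholders, not the defaults used by the simulator in section 3.1.b.

```python
# Toy illustration of a simulated high-density SMLM frame: random emitter positions
# rendered with a Gaussian PSF, plus background and shot noise. Grid size, PSF sigma,
# emitter count and photon counts below are arbitrary placeholder values.
import numpy as np
from scipy.ndimage import gaussian_filter

rng = np.random.default_rng(1)
size, n_emitters, psf_sigma_px = 64, 50, 1.5

frame = np.zeros((size, size), dtype=np.float32)
ys = rng.integers(0, size, n_emitters)
xs = rng.integers(0, size, n_emitters)
photons = rng.uniform(500, 2000, n_emitters)
np.add.at(frame, (ys, xs), photons)                    # place emitters on the pixel grid

blurred = gaussian_filter(frame, sigma=psf_sigma_px)   # PSF blur
noisy = rng.poisson(blurred + 10).astype(np.float32)   # background + shot noise
```
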
Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "E04mOlG_H5Tz" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "F_tjlGzsH-Dn" - }, - "source": [ - "\n", - "## **1.1. Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "gn-LaaNNICqL" - }, - "outputs": [], - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "# %tensorflow_version 1.x\n", - "\n", - "import tensorflow as tf\n", - "# if tf.__version__ != '2.2.0':\n", - "# !pip install tensorflow==2.2.0\n", - "\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime settings are correct then Google did not allocate GPU to your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi\n", - "\n", - "# from tensorflow.python.client import device_lib \n", - "# device_lib.list_local_devices()\n", - "\n", - "# print the tensorflow version\n", - "print('Tensorflow version is ' + str(tf.__version__))\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tnP7wM79IKW-" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "1R-7Fo34_gOd" - }, - "outputs": [], - "source": [ - "#@markdown ##Run this cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "#mounts user's Google Drive to Google Colab.\n", - "\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jRnQZWSZhArJ" - }, - "source": [ - "# **2. Install Deep-STORM and dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "kSrZMo3X_NhO" - }, - "outputs": [], - "source": [ - "Notebook_version = ['1.12']\n", - "\n", - "#@markdown ##Install Deep-STORM and dependencies\n", - "\n", - "\n", - "# %% Model definition + helper functions\n", - "\n", - "!pip install fpdf\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "# Import keras modules and libraries\n", - "from tensorflow import keras\n", - "from tensorflow.keras.models import Model\n", - "from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer\n", - "from tensorflow.keras.callbacks import Callback\n", - "from tensorflow.keras import backend as K\n", - "from tensorflow.keras import optimizers, losses\n", - "\n", - "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n", - "from tensorflow.keras.callbacks import ModelCheckpoint\n", - "from tensorflow.keras.callbacks import ReduceLROnPlateau\n", - "from skimage.transform import warp\n", - "from skimage.transform import SimilarityTransform\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from scipy.signal import fftconvolve\n", - "\n", - "# Import common libraries\n", - "import tensorflow as tf\n", - "import numpy as np\n", - "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", - "import h5py\n", - "import scipy.io as sio\n", - "from os.path import abspath\n", - "from sklearn.model_selection import train_test_split\n", - "from skimage import io\n", - "import time\n", - "import os\n", - "import shutil\n", - "import csv\n", - "from PIL import Image \n", - "from PIL.TiffTags import TAGS\n", - "from scipy.ndimage import gaussian_filter\n", - "import math\n", - "from astropy.visualization import simple_norm\n", - "from sys import getsizeof\n", - "from fpdf import FPDF, HTMLMixin\n", - "from pip._internal.operations.freeze import freeze\n", - "import subprocess\n", - "from datetime import datetime\n", - "\n", - "\n", - "# For sliders and dropdown menu, progress bar\n", - "from ipywidgets import interact\n", - "import ipywidgets as widgets\n", - "from tqdm import tqdm\n", - "\n", - "# For Multi-threading in simulation\n", - "from numba import njit, prange\n", - "\n", - "\n", - "\n", - "# define a function that projects and rescales an image to the range [0,1]\n", - "def project_01(im):\n", - " im = np.squeeze(im)\n", - " min_val = im.min()\n", - " max_val = im.max()\n", - " return (im - min_val)/(max_val - 
min_val)\n", - "\n", - "# normalize image given mean and std\n", - "def normalize_im(im, dmean, dstd):\n", - " im = np.squeeze(im)\n", - " im_norm = np.zeros(im.shape,dtype=np.float32)\n", - " im_norm = (im - dmean)/dstd\n", - " return im_norm\n", - "\n", - "# Define the loss history recorder\n", - "class LossHistory(Callback):\n", - " def on_train_begin(self, logs={}):\n", - " self.losses = []\n", - "\n", - " def on_batch_end(self, batch, logs={}):\n", - " self.losses.append(logs.get('loss'))\n", - " \n", - "# Define a matlab like gaussian 2D filter\n", - "def matlab_style_gauss2D(shape=(7,7),sigma=1):\n", - " \"\"\" \n", - " 2D gaussian filter - should give the same result as:\n", - " MATLAB's fspecial('gaussian',[shape],[sigma]) \n", - " \"\"\"\n", - " m,n = [(ss-1.)/2. for ss in shape]\n", - " y,x = np.ogrid[-m:m+1,-n:n+1]\n", - " h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n", - " h.astype(dtype=K.floatx())\n", - " h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n", - " sumh = h.sum()\n", - " if sumh != 0:\n", - " h /= sumh\n", - " h = h*2.0\n", - " h = h.astype('float32')\n", - " return h\n", - "\n", - "# Expand the filter dimensions\n", - "psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n", - "gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n", - "\n", - "# Combined MSE + L1 loss\n", - "def L1L2loss(input_shape):\n", - " def bump_mse(heatmap_true, spikes_pred):\n", - "\n", - " # generate the heatmap corresponding to the predicted spikes\n", - " heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n", - "\n", - " # heatmaps MSE\n", - " loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n", - "\n", - " # l1 on the predicted spikes\n", - " loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n", - " return loss_heatmaps + loss_spikes\n", - " return bump_mse\n", - "\n", - "# Define the concatenated conv2, batch normalization, and relu block\n", - "def conv_bn_relu(nb_filter, rk, ck, name):\n", - " def f(input):\n", - " conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n", - " padding=\"same\", use_bias=False,\\\n", - " kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n", - " conv_norm = BatchNormalization(name='BN-'+name)(conv)\n", - " conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n", - " return conv_norm_relu\n", - " return f\n", - "\n", - "# Define the model architechture\n", - "def CNN(input,names):\n", - " Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n", - " pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n", - " Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n", - " pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n", - " Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n", - " pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n", - " Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)\n", - " up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n", - " Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n", - " up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n", - " Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n", - " up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n", - " Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n", - " return Features7\n", - "\n", - "# Define the Model building for an arbitrary input size\n", - "def buildModel(input_dim, initial_learning_rate = 0.001):\n", - " input_ = Input (shape = 
(input_dim))\n", - " act_ = CNN (input_,'CNN')\n", - " density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n", - " activation=\"linear\", use_bias = False,\\\n", - " kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n", - " model = Model (inputs= input_, outputs=density_pred)\n", - " opt = optimizers.Adam(lr = initial_learning_rate)\n", - " model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n", - " return model\n", - "\n", - "\n", - "# define a function that trains a model for a given data SNR and density\n", - "def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):\n", - " \n", - " \"\"\"\n", - " This function trains a CNN model on the desired training set, given the \n", - " upsampled training images and labels generated in MATLAB.\n", - " \n", - " # Inputs\n", - " # TO UPDATE ----------\n", - "\n", - " # Outputs\n", - " function saves the weights of the trained model to a hdf5, and the \n", - " normalization factors to a mat file. These will be loaded later for testing \n", - " the model in test_model. \n", - " \"\"\"\n", - " \n", - " # for reproducibility\n", - " np.random.seed(123)\n", - "\n", - " X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)\n", - " print('Number of training examples: %d' % X_train.shape[0])\n", - " print('Number of validation examples: %d' % X_test.shape[0])\n", - " \n", - " # Setting type\n", - " X_train = X_train.astype('float32')\n", - " X_test = X_test.astype('float32')\n", - " y_train = y_train.astype('float32')\n", - " y_test = y_test.astype('float32')\n", - "\n", - " \n", - " #===================== Training set normalization ==========================\n", - " # normalize training images to be in the range [0,1] and calculate the \n", - " # training set mean and std\n", - " mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n", - " std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n", - " for i in range(X_train.shape[0]):\n", - " X_train[i, :, :] = project_01(X_train[i, :, :])\n", - " mean_train[i] = X_train[i, :, :].mean()\n", - " std_train[i] = X_train[i, :, :].std()\n", - "\n", - " # resulting normalized training images\n", - " mean_val_train = mean_train.mean()\n", - " std_val_train = std_train.mean()\n", - " X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n", - " for i in range(X_train.shape[0]):\n", - " X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n", - " \n", - " # patch size\n", - " psize = X_train_norm.shape[1]\n", - "\n", - " # Reshaping\n", - " X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n", - "\n", - " # ===================== Test set normalization ==========================\n", - " # normalize test images to be in the range [0,1] and calculate the test set \n", - " # mean and std\n", - " mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n", - " std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n", - " for i in range(X_test.shape[0]):\n", - " X_test[i, :, :] = project_01(X_test[i, :, :])\n", - " mean_test[i] = X_test[i, :, :].mean()\n", - " std_test[i] = X_test[i, :, :].std()\n", - "\n", - " # resulting normalized test images\n", - " mean_val_test = mean_test.mean()\n", - " std_val_test = std_test.mean()\n", - " X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n", - 
" for i in range(X_test.shape[0]):\n", - " X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n", - " \n", - " # Reshaping\n", - " X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n", - "\n", - " # Reshaping labels\n", - " Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n", - " Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n", - "\n", - " # Save datasets to a matfile to open later in matlab\n", - " mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test, \"upsampling_factor\": upsampling_factor, \"Normalization factor\": L2_weighting_factor}\n", - " sio.savemat(os.path.join(modelPath,\"model_metadata.mat\"), mdict)\n", - "\n", - "\n", - " # Set the dimensions ordering according to tensorflow consensous\n", - " # K.set_image_dim_ordering('tf')\n", - " K.set_image_data_format('channels_last')\n", - "\n", - " # Save the model weights after each epoch if the validation loss decreased\n", - " checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_best.hdf5\"), verbose=1,\n", - " save_best_only=True)\n", - "\n", - " # Change learning when loss reaches a plataeu\n", - " change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)\n", - " \n", - " # Model building and complitation\n", - " model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)\n", - " model.summary()\n", - "\n", - " # Load pretrained model\n", - " if not pretrained_model_path:\n", - " print('Using random initial model weights.')\n", - " else:\n", - " print('Loading model weights from '+pretrained_model_path)\n", - " model.load_weights(pretrained_model_path)\n", - " \n", - " # Create an image data generator for real time data augmentation\n", - " datagen = ImageDataGenerator(\n", - " featurewise_center=False, # set input mean to 0 over the dataset\n", - " samplewise_center=False, # set each sample mean to 0\n", - " featurewise_std_normalization=False, # divide inputs by std of the dataset\n", - " samplewise_std_normalization=False, # divide each input by its std\n", - " zca_whitening=False, # apply ZCA whitening\n", - " rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n", - " width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n", - " height_shift_range=0., # randomly shift images vertically (fraction of total height)\n", - " zoom_range=0.,\n", - " shear_range=0.,\n", - " horizontal_flip=False, # randomly flip images\n", - " vertical_flip=False, # randomly flip images\n", - " fill_mode='constant',\n", - " data_format=K.image_data_format())\n", - "\n", - " # Fit the image generator on the training data\n", - " datagen.fit(X_train_norm)\n", - " \n", - " # loss history recorder\n", - " history = LossHistory()\n", - "\n", - " # Inform user training begun\n", - " print('-------------------------------')\n", - " print('Training model...')\n", - "\n", - " # Fit model on the batches generated by datagen.flow()\n", - " train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \n", - " steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \n", - " validation_data=(X_test_norm, Y_test), \n", - " callbacks=[history, checkpointer, change_lr]) \n", - "\n", - " # Inform user training ended\n", - " print('-------------------------------')\n", - " print('Training Complete!')\n", - " \n", - " # Save the last model\n", - " model.save(os.path.join(modelPath, 'weights_last.hdf5'))\n", - "\n", 
- " # convert the history.history dict to a pandas DataFrame: \n", - " lossData = pd.DataFrame(train_history.history) \n", - "\n", - " if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n", - " shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n", - "\n", - " os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n", - "\n", - " # The training evaluation.csv is saved (overwrites the Files if needed). \n", - " lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n", - " with open(lossDataCSVpath, 'w') as f:\n", - " writer = csv.writer(f)\n", - " writer.writerow(['loss','val_loss','learning rate'])\n", - " for i in range(len(train_history.history['loss'])):\n", - " writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])\n", - "\n", - " return\n", - "\n", - "\n", - "# Normalization functions from Martin Weigert used in CARE\n", - "def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " \"\"\"Percentile-based image normalization.\"\"\"\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - "def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " if dtype is not None:\n", - " x = x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - "def norm_minmse(gt, x, normalize_gt=True):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - "\n", - " \"\"\"\n", - " normalizes and affinely scales an image pair such that the MSE is minimized \n", - " \n", - " Parameters\n", - " ----------\n", - " gt: ndarray\n", - " the ground truth image \n", - " x: ndarray\n", - " the image that will be affinely scaled \n", - " normalize_gt: bool\n", - " set to True of gt image should be normalized (default)\n", - " Returns\n", - " -------\n", - " gt_scaled, x_scaled \n", - " \"\"\"\n", - " if normalize_gt:\n", - " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", - " x = x.astype(np.float32, copy=False) - np.mean(x)\n", - " #x = x - np.mean(x)\n", - " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", - " #gt = gt - np.mean(gt)\n", - " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", - " return gt, scale * x\n", - "\n", - "\n", - "# Multi-threaded Erf-based image construction\n", - "@njit(parallel=True)\n", - "def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):\n", - " w = image_size[0]\n", - " h = image_size[1]\n", - " erfImage = np.zeros((w, h))\n", - " for ij in prange(w*h):\n", - " j = int(ij/w)\n", - " i = ij - j*w\n", - " for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n", - " # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n", - " 
if (sigma > 0) and (photon > 0):\n", - " S = sigma*math.sqrt(2)\n", - " x = i*pixel_size - xc\n", - " y = j*pixel_size - yc\n", - " # Don't bother if the emitter is further than 4 sigma from the centre of the pixel\n", - " if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n", - " ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n", - " ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n", - " erfImage[j][i] += 0.25*photon*ErfX*ErfY\n", - " return erfImage\n", - "\n", - "\n", - "@njit(parallel=True)\n", - "def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):\n", - " w = image_size[0]\n", - " h = image_size[1]\n", - " locImage = np.zeros((image_size[0],image_size[1]) )\n", - " n_locs = len(xc_array)\n", - "\n", - " for e in prange(n_locs):\n", - " locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1\n", - "\n", - " return locImage\n", - "\n", - "\n", - "\n", - "def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n", - " with Image.open(TIFFpath) as img:\n", - " meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n", - "\n", - "\n", - " # TIFF tags\n", - " # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n", - " # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n", - " ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n", - " width = meta_dict['ImageWidth'][0]\n", - " height = meta_dict['ImageLength'][0]\n", - "\n", - " xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n", - "\n", - " if len(xResolution) == 1:\n", - " xResolution = xResolution[0]\n", - " elif len(xResolution) == 2:\n", - " xResolution = xResolution[0]/xResolution[1]\n", - " else:\n", - " print('Image resolution not defined.')\n", - " xResolution = 1\n", - "\n", - " if ResolutionUnit == 2:\n", - " # Units given are in inches\n", - " pixel_size = 0.025*1e9/xResolution\n", - " elif ResolutionUnit == 3:\n", - " # Units given are in cm\n", - " pixel_size = 0.01*1e9/xResolution\n", - " else: \n", - " # ResolutionUnit is therefore 1\n", - " print('Resolution unit not defined. 
Assuming: um')\n", - " pixel_size = 1e3/xResolution\n", - "\n", - " if display:\n", - " print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n", - " print('Image size: '+str(width)+'x'+str(height))\n", - " \n", - " return (pixel_size, width, height)\n", - "\n", - "\n", - "def saveAsTIF(path, filename, array, pixel_size):\n", - " \"\"\"\n", - " Image saving using PIL to save as .tif format\n", - " # Input \n", - " path - path where it will be saved\n", - " filename - name of the file to save (no extension)\n", - " array - numpy array conatining the data at the required format\n", - " pixel_size - physical size of pixels in nanometers (identical for x and y)\n", - " \"\"\"\n", - "\n", - " # print('Data type: '+str(array.dtype))\n", - " if (array.dtype == np.uint16):\n", - " mode = 'I;16'\n", - " elif (array.dtype == np.uint32):\n", - " mode = 'I'\n", - " else:\n", - " mode = 'F'\n", - "\n", - " # Rounding the pixel size to the nearest number that divides exactly 1cm.\n", - " # Resolution needs to be a rational number --> see TIFF format\n", - " # pixel_size = 10000/(round(10000/pixel_size))\n", - "\n", - " if len(array.shape) == 2:\n", - " im = Image.fromarray(array)\n", - " im.save(os.path.join(path, filename+'.tif'),\n", - " mode = mode, \n", - " resolution_unit = 3,\n", - " resolution = 0.01*1e9/pixel_size)\n", - "\n", - "\n", - " elif len(array.shape) == 3:\n", - " imlist = []\n", - " for frame in array:\n", - " imlist.append(Image.fromarray(frame))\n", - "\n", - " imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,\n", - " append_images=imlist[1:],\n", - " mode = mode, \n", - " resolution_unit = 3,\n", - " resolution = 0.01*1e9/pixel_size)\n", - "\n", - " return\n", - "\n", - "\n", - "\n", - "\n", - "class Maximafinder(Layer):\n", - " def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):\n", - " super(Maximafinder, self).__init__(**kwargs)\n", - " self.thresh = tf.constant(thresh, dtype=tf.float32)\n", - " self.nhood = neighborhood_size\n", - " self.use_local_avg = use_local_avg\n", - "\n", - " def build(self, input_shape):\n", - " if self.use_local_avg is True:\n", - " self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])\n", - " self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n", - " self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n", - "\n", - " def call(self, inputs):\n", - "\n", - " # local maxima positions\n", - " max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)\n", - " cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)\n", - " indices = tf.where(cond)\n", - " bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]\n", - " confidence = tf.gather_nd(inputs, indices)\n", - "\n", - " # local CoG estimator\n", - " if self.use_local_avg:\n", - " x_image = K.conv2d(inputs, self.kernel_x, padding='same')\n", - " y_image = K.conv2d(inputs, self.kernel_y, padding='same')\n", - " sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')\n", - " confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)\n", - " x_local = tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))\n", - " y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))\n", - " xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, 
dtype=tf.float32)\n", - " yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)\n", - " else:\n", - " xind = tf.cast(xind, dtype=tf.float32)\n", - " yind = tf.cast(yind, dtype=tf.float32)\n", - " \n", - " return bind, xind, yind, confidence\n", - "\n", - " def get_config(self):\n", - "\n", - " # Implement get_config to enable serialization. This is optional.\n", - " base_config = super(Maximafinder, self).get_config()\n", - " config = {}\n", - " return dict(list(base_config.items()) + list(config.items()))\n", - "\n", - "\n", - "\n", - "# ------------------------------- Prediction with postprocessing function-------------------------------\n", - "def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):\n", - " \"\"\"\n", - " This function tests a trained model on the desired test set, given the \n", - " tiff stack of test images, learned weights, and normalization factors.\n", - " \n", - " # Inputs\n", - " dataPath - the path to the folder containing the tiff stack(s) to run prediction on \n", - " filename - the name of the file to process\n", - " modelPath - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model\n", - " savePath - the path to the folder where to save the prediction\n", - " batch_size. - the number of frames to predict on for each iteration\n", - " thresh - threshoold percentage from the maximum of the gaussian scaling\n", - " neighborhood_size - the size of the neighborhood for local maxima finding\n", - " use_local_average - Boolean whether to perform local averaging or not\n", - " \"\"\"\n", - " \n", - " # load mean and std\n", - " matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))\n", - " test_mean = np.array(matfile['mean_test'])\n", - " test_std = np.array(matfile['std_test']) \n", - " upsampling_factor = np.array(matfile['upsampling_factor'])\n", - " upsampling_factor = upsampling_factor.item() # convert to scalar\n", - " L2_weighting_factor = np.array(matfile['Normalization factor'])\n", - " L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n", - "\n", - " # Read in the raw file\n", - " Images = io.imread(os.path.join(dataPath, filename))\n", - " if pixel_size == None:\n", - " pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)\n", - " pixel_size_hr = pixel_size/upsampling_factor\n", - "\n", - " # get dataset dimensions\n", - " (nFrames, M, N) = Images.shape\n", - " print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')\n", - "\n", - " # Build the model for a bigger image\n", - " model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))\n", - "\n", - " # Load the trained weights\n", - " model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))\n", - "\n", - " # add a post-processing module\n", - " max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)\n", - "\n", - " # Initialise the results: lists will be used to collect all the localizations\n", - " frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []\n", - "\n", - " # Initialise the results\n", - " Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n", - " Widefield = np.zeros((M, N), dtype=np.float32)\n", - "\n", - " # run model in batches\n", - " n_batches = math.ceil(nFrames/batch_size)\n", - " for b in 
tqdm(range(n_batches)):\n", - "\n", - " nF = min(batch_size, nFrames - b*batch_size)\n", - " Images_norm = np.zeros((nF, M, N),dtype=np.float32)\n", - " Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n", - "\n", - " # Upsampling using a simple nearest neighbor interp and calculating - MULTI-THREAD this?\n", - " for f in range(nF):\n", - " Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])\n", - " Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)\n", - " Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))\n", - " Widefield += Images[b*batch_size+f,:,:]\n", - "\n", - " # Reshaping\n", - " Images_upsampled = np.expand_dims(Images_upsampled,axis=3)\n", - "\n", - " # Run prediction and local amxima finding\n", - " predicted_density = model.predict_on_batch(Images_upsampled)\n", - " predicted_density[predicted_density < 0] = 0\n", - " Prediction += predicted_density.sum(axis = 3).sum(axis = 0)\n", - "\n", - " bind, xind, yind, confidence = max_layer(predicted_density)\n", - " \n", - " # normalizing the confidence by the L2_weighting_factor\n", - " confidence /= L2_weighting_factor \n", - "\n", - " # turn indices to nms and append to the results\n", - " xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n", - " frmind = (bind.numpy() + b*batch_size + 1).tolist()\n", - " xind = xind.numpy().tolist()\n", - " yind = yind.numpy().tolist()\n", - " confidence = confidence.numpy().tolist()\n", - " frame_number_list += frmind\n", - " x_nm_list += xind\n", - " y_nm_list += yind\n", - " confidence_au_list += confidence\n", - "\n", - " # Open and create the csv file that will contain all the localizations\n", - " if use_local_avg:\n", - " ext = '_avg'\n", - " else:\n", - " ext = '_max'\n", - " with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - " writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n", - " locs = list(zip(frame_number_list, x_nm_list, y_nm_list, confidence_au_list))\n", - " writer.writerows(locs)\n", - "\n", - " # Save the prediction and widefield image\n", - " Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))\n", - " Widefield = np.float32(Widefield)\n", - "\n", - " # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)\n", - " # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)\n", - "\n", - " saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)\n", - " saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)\n", - "\n", - "\n", - " return\n", - "\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - " NORMAL = '\\033[0m' # white (normal)\n", - "\n", - "\n", - "\n", - "def list_files(directory, extension):\n", - " return (f for f in os.listdir(directory) if f.endswith('.' 
+ extension))\n", - "\n", - "\n", - "# @njit(parallel=True)\n", - "def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):\n", - " xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')\n", - " centralPatch = XC[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]\n", - "\n", - " if (method == 'MAX'):\n", - " x0 = xMaxInd\n", - " y0 = yMaxInd\n", - "\n", - " elif (method == 'CoM'):\n", - " x0 = 0\n", - " y0 = 0\n", - " S = 0\n", - " for xy in range(patch_size*patch_size):\n", - " y = math.floor(xy/patch_size)\n", - " x = xy - y*patch_size\n", - " x0 += x*array[x,y]\n", - " y0 += y*array[x,y]\n", - " S = array[x,y]\n", - " \n", - " x0 = x0/S - patch_size/2 + xMaxInd\n", - " y0 = y0/S - patch_size/2 + yMaxInd\n", - " \n", - " elif (method == 'Radiality'):\n", - " # Not implemented yet\n", - " x0 = xMaxInd\n", - " y0 = yMaxInd\n", - " \n", - " return (x0, y0)\n", - "\n", - "\n", - "@njit(parallel=True)\n", - "def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):\n", - " n_locs = xc_array.shape[0]\n", - " xc_array_Corr = np.empty(n_locs)\n", - " yc_array_Corr = np.empty(n_locs)\n", - " \n", - " for loc in prange(n_locs):\n", - " xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]\n", - " yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]\n", - "\n", - " return (xc_array_Corr, yc_array_Corr)\n", - "\n", - "\n", - "print('--------------------------------')\n", - "print('DeepSTORM installation complete.')\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "if Notebook_version == list(Latest_notebook_version.columns):\n", - " print(\"This notebook is up-to-date.\")\n", - "\n", - "if not Notebook_version == list(Latest_notebook_version.columns):\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "def pdf_export(trained = False, raw_data = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'Deep-STORM'\n", - " #model_name = 'little_CARE_test'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hours)+ \"hour(s) \"+str(minutes)+\"min(s) \"+str(round(seconds))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and method:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - " if raw_data == True:\n", - " shape = (M,N)\n", - " else:\n", - " shape = (int(FOV_size/pixel_size),int(FOV_size/pixel_size))\n", - " #dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+' GPU.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. The models was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(180, 5, txt = text, align='L')\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if raw_data==False:\n", - " simul_text = 'The training dataset was created in the notebook using the following simulation settings:'\n", - " pdf.cell(200, 5, txt=simul_text, align='L')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<table>
  <tr><th>Setting</th><th>Simulated Value</th></tr>
  <tr><td>FOV_size</td><td>{0}</td></tr>
  <tr><td>pixel_size</td><td>{1}</td></tr>
  <tr><td>ADC_per_photon_conversion</td><td>{2}</td></tr>
  <tr><td>ReadOutNoise_ADC</td><td>{3}</td></tr>
  <tr><td>ADC_offset</td><td>{4}</td></tr>
  <tr><td>emitter_density</td><td>{5}</td></tr>
  <tr><td>emitter_density_std</td><td>{6}</td></tr>
  <tr><td>number_of_frames</td><td>{7}</td></tr>
  <tr><td>sigma</td><td>{8}</td></tr>
  <tr><td>sigma_std</td><td>{9}</td></tr>
  <tr><td>n_photons</td><td>{10}</td></tr>
  <tr><td>n_photons_std</td><td>{11}</td></tr>
</table>
\n", - " \"\"\".format(FOV_size, pixel_size, ADC_per_photon_conversion, ReadOutNoise_ADC, ADC_offset, emitter_density, emitter_density_std, number_of_frames, sigma, sigma_std, n_photons, n_photons_std)\n", - " pdf.write_html(html)\n", - " else:\n", - " simul_text = 'The training dataset was simulated using ThunderSTORM and loaded into the notebook.'\n", - " pdf.multi_cell(190, 5, txt=simul_text, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " #pdf.ln(1)\n", - " #pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(29, 5, txt= 'ImageData_path', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = ImageData_path, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'LocalizationData_path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = LocalizationData_path, align = 'L')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'pixel_size:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = str(pixel_size), align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " # if Use_Default_Advanced_Parameters:\n", - " # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used to generate patches:')\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(str(patch_size)+'x'+str(patch_size), upsampling_factor, num_patches_per_frame, min_number_of_emitters_per_patch, max_num_patches, gaussian_sigma, Automatic_normalization, L2_weighting_factor)\n", - " pdf.write_html(html)\n", - " pdf.ln(3)\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - "
<table>
  <tr><th>Patch Parameter</th><th>Value</th></tr>
  <tr><td>patch_size</td><td>{0}</td></tr>
  <tr><td>upsampling_factor</td><td>{1}</td></tr>
  <tr><td>num_patches_per_frame</td><td>{2}</td></tr>
  <tr><td>min_number_of_emitters_per_patch</td><td>{3}</td></tr>
  <tr><td>max_num_patches</td><td>{4}</td></tr>
  <tr><td>gaussian_sigma</td><td>{5}</td></tr>
  <tr><td>Automatic_normalization</td><td>{6}</td></tr>
  <tr><td>L2_weighting_factor</td><td>{7}</td></tr>
</table>
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<table>
  <tr><th>Training Parameter</th><th>Value</th></tr>
  <tr><td>number_of_epochs</td><td>{0}</td></tr>
  <tr><td>batch_size</td><td>{1}</td></tr>
  <tr><td>number_of_steps</td><td>{2}</td></tr>
  <tr><td>percentage_validation</td><td>{3}</td></tr>
  <tr><td>initial_learning_rate</td><td>{4}</td></tr>
</table>
\n", - " \"\"\".format(number_of_epochs,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " # pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - "\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training Images', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_DeepSTORM2D.png').shape\n", - " pdf.image('/content/TrainingDataExample_DeepSTORM2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " # if Use_Data_augmentation:\n", - " # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n", - " print('------------------------------')\n", - " print('PDF report exported in '+model_path+'/'+model_name+'/')\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'Deep-STORM'\n", - " #model_name = os.path.basename(full_QC_model_path)\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+os.path.basename(QC_model_path)+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n", - " pdf.ln(1)\n", - " if os.path.exists(savePath+'/lossCurvePlots.png'):\n", - " exp_size = io.imread(savePath+'/lossCurvePlots.png').shape\n", - " pdf.image(savePath+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " 
pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(savePath+'/QC_example_data.png').shape\n", - " pdf.image(savePath+'/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(savePath+'/'+os.path.basename(QC_model_path)+'_QC_metrics.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " mSSIM_PvsGT = header[1]\n", - " mSSIM_SvsGT = header[2]\n", - " NRMSE_PvsGT = header[3]\n", - " NRMSE_SvsGT = header[4]\n", - " PSNR_PvsGT = header[5]\n", - " PSNR_SvsGT = header[6]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n", - " html = html+header\n", - " for row in metrics:\n", - " image = row[0]\n", - " mSSIM_PvsGT = row[1]\n", - " mSSIM_SvsGT = row[2]\n", - " NRMSE_PvsGT = row[3]\n", - " NRMSE_SvsGT = row[4]\n", - " PSNR_PvsGT = row[5]\n", - " PSNR_SvsGT = row[6]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th></tr>\n", - "      \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n", - "    html = html+header\n", - "    for row in metrics:\n", - "      image = row[0]\n", - "      mSSIM_PvsGT = row[1]\n", - "      mSSIM_SvsGT = row[2]\n", - "      NRMSE_PvsGT = row[3]\n", - "      NRMSE_SvsGT = row[4]\n", - "      PSNR_PvsGT = row[5]\n", - "      PSNR_SvsGT = row[6]\n", - "      cells = \"\"\"\n", - "        
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>\n", - "      \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n", - "      html = html+cells\n", - "    html = html+\"\"\"</table>
\"\"\"\n", - " \n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n", - "\n", - "\n", - " print('------------------------------')\n", - " print('QC PDF report exported as '+savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n", - "\n", - "\n", - "\n", - "# Exporting requirements.txt for local run\n", - "!pip freeze > requirements.txt\n", - "\n", - "after = [str(m) for m in sys.modules]\n", - "# Get minimum requirements file\n", - "\n", - "#Add the following lines before all imports: \n", - "# import sys\n", - "# before = [str(m) for m in sys.modules]\n", - "\n", - "#Add the following line after the imports:\n", - "# after = [str(m) for m in sys.modules]\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "df = pd.read_csv('requirements.txt', delimiter = \"\\n\")\n", - "mod_list = [m.split('.')[0] for m in after if not m in before]\n", - "\n", - "req_list_temp = df.values.tolist()\n", - "req_list = [x[0] for x in req_list_temp]\n", - "\n", - "# Replace with package name \n", - "mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - "mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - "filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - "file=open('DeepSTORM_2D_requirements_simple.txt','w')\n", - "for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - "file.close()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vu8f5NGJkJos" - }, - "source": [ - "\n", - "# **3. Generate patches for training**\n", - "---\n", - "\n", - "For Deep-STORM the training data can be obtained in two ways:\n", - "* Simulated using ThunderSTORM or other simulation tool and loaded here (**using Section 3.1.a**)\n", - "* Directly simulated in this notebook (**using Section 3.1.b**)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WSV8xnlynp0l" - }, - "source": [ - "## **3.1.a Load training data**\n", - "---\n", - "\n", - "Here you can load your simulated data along with its corresponding localization file.\n", - "* The `pixel_size` is defined in nanometer (nm). 
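For reference, the loading cell below reads the localization file with pandas, and the rest of the notebook relies on the 'frame', 'x [nm]' and 'y [nm]' columns (ThunderSTORM-style, with an id index starting at 1). A minimal compatible table could be written like this; the file name and values are invented for illustration only:

import pandas as pd
example_locs = pd.DataFrame({'frame':  [1, 1, 2],
                             'x [nm]': [1523.4, 2310.7, 305.0],
                             'y [nm]': [880.2, 4012.9, 1299.8]})
example_locs.index += 1  # ThunderSTORM-style id column starting at 1
example_locs.to_csv('example_localizations.csv')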
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "CT6SNcfNg6j0" - }, - "outputs": [], - "source": [ - "#@markdown ##Load raw data\n", - "\n", - "load_raw_data = True\n", - "\n", - "# Get user input\n", - "ImageData_path = \"\" #@param {type:\"string\"}\n", - "LocalizationData_path = \"\" #@param {type: \"string\"}\n", - "#@markdown Get pixel size from file?\n", - "get_pixel_size_from_file = True #@param {type:\"boolean\"}\n", - "#@markdown Otherwise, use this value:\n", - "pixel_size = 100 #@param {type:\"number\"}\n", - "\n", - "if get_pixel_size_from_file:\n", - " pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)\n", - "\n", - "# load the tiff data\n", - "Images = io.imread(ImageData_path)\n", - "# get dataset dimensions\n", - "if len(Images.shape) == 3:\n", - " (number_of_frames, M, N) = Images.shape\n", - "elif len(Images.shape) == 2:\n", - " (M, N) = Images.shape\n", - " number_of_frames = 1\n", - "print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n", - "\n", - "# Interactive display of the stack\n", - "def scroll_in_time(frame):\n", - " f=plt.figure(figsize=(6,6))\n", - " plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')\n", - " plt.title('Training source at frame = ' + str(frame))\n", - " plt.axis('off');\n", - "\n", - "if number_of_frames > 1:\n", - " interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n", - "else:\n", - " f=plt.figure(figsize=(6,6))\n", - " plt.imshow(Images, interpolation='nearest', cmap = 'gray')\n", - " plt.title('Training source')\n", - " plt.axis('off');\n", - "\n", - "# Load the localization file and display the first\n", - "LocData = pd.read_csv(LocalizationData_path, index_col=0)\n", - "LocData.tail()\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K9xE5GeYiks9" - }, - "source": [ - "## **3.1.b Simulate training data**\n", - "---\n", - "This simulation tool allows you to generate SMLM data of randomly distrubuted emitters in a field-of-view. \n", - "The assumptions are as follows:\n", - "\n", - "* Gaussian Point Spread Function (PSF) with standard deviation defined by `Sigma`. The nominal value of `sigma` can be evaluated using `sigma = 0.21 x Lambda / NA`. \n", - "* Each emitter will emit `n_photons` per frame, and generate their equivalent Poisson noise.\n", - "* The camera will contribute Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC` in ADC\n", - "* The `emitter_density` is defined as the number of emitters / um^2 on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`. The latter parameter represents the standard deviation of the normal distribution that the density is drawn from for each individual frame. `emitter_density` **is defined in number of emitters / um^2**.\n", - "* The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.\n", - "\n", - "Important note:\n", - "- All dimensions are in nanometer (e.g. 
`FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "sQyLXpEhitsg" - }, - "outputs": [], - "source": [ - "load_raw_data = False\n", - "\n", - "# ---------------------------- User input ----------------------------\n", - "#@markdown Run the simulation\n", - "#@markdown --- \n", - "#@markdown Camera settings: \n", - "FOV_size = 6400#@param {type:\"number\"}\n", - "pixel_size = 100#@param {type:\"number\"}\n", - "ADC_per_photon_conversion = 1 #@param {type:\"number\"}\n", - "ReadOutNoise_ADC = 4.5#@param {type:\"number\"}\n", - "ADC_offset = 50#@param {type:\"number\"}\n", - "\n", - "#@markdown Acquisition settings: \n", - "emitter_density = 6#@param {type:\"number\"}\n", - "emitter_density_std = 0#@param {type:\"number\"}\n", - "\n", - "number_of_frames = 20#@param {type:\"integer\"}\n", - "\n", - "sigma = 110 #@param {type:\"number\"}\n", - "sigma_std = 5 #@param {type:\"number\"}\n", - "# NA = 1.1 #@param {type:\"number\"}\n", - "# wavelength = 800#@param {type:\"number\"}\n", - "# wavelength_std = 150#@param {type:\"number\"}\n", - "n_photons = 2250#@param {type:\"number\"}\n", - "n_photons_std = 250#@param {type:\"number\"}\n", - "\n", - "\n", - "# ---------------------------- Variable initialisation ----------------------------\n", - "# Start the clock to measure how long it takes\n", - "start = time.time()\n", - "\n", - "print('-----------------------------------------------------------')\n", - "n_molecules = emitter_density*FOV_size*FOV_size/10**6\n", - "n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6\n", - "print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))\n", - "\n", - "# sigma = 0.21*wavelength/NA\n", - "# sigma_std = 0.21*wavelength_std/NA\n", - "# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n", - "\n", - "M = N = round(FOV_size/pixel_size)\n", - "FOV_size = M*pixel_size\n", - "print('Final image size: '+str(M)+'x'+str(M)+' ('+str(round(FOV_size/1000, 3))+'um x'+str(round(FOV_size/1000,3))+' um)')\n", - "\n", - "np.random.seed(1)\n", - "display_upsampling = 8 # used to display the loc map here\n", - "NoiseFreeImages = np.zeros((number_of_frames, M, M))\n", - "locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n", - "\n", - "frames = []\n", - "all_xloc = []\n", - "all_yloc = []\n", - "all_photons = []\n", - "all_sigmas = []\n", - "\n", - "# ---------------------------- Main simulation loop ----------------------------\n", - "print('-----------------------------------------------------------')\n", - "for f in tqdm(range(number_of_frames)):\n", - " \n", - " # Define the coordinates of emitters by randomly distributing them across the FOV\n", - " n_mol = int(max(round(np.random.normal(n_molecules, n_molecules_std, size=1)[0]), 0))\n", - " x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n", - " y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n", - " photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)\n", - " sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)\n", - " # x_c = np.linspace(0,3000,5)\n", - " # y_c = np.linspace(0,3000,5)\n", - "\n", - " all_xloc += x_c.tolist()\n", - " all_yloc += y_c.tolist()\n", - " frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n", - " all_photons += photon_array.tolist()\n", - " all_sigmas += sigma_array.tolist()\n", - "\n", - 
" locImage[f] = FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)\n", - "\n", - " # # Get the approximated locations according to the grid pixel size\n", - " # Chr_emitters = [int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n", - " # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n", - "\n", - " # # Build Localization image\n", - " # for (r,c) in zip(Rhr_emitters, Chr_emitters):\n", - " # locImage[f][r][c] += 1\n", - "\n", - " NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, image_size = (M,M), pixel_size = pixel_size)\n", - "\n", - "\n", - "# ---------------------------- Create DataFrame fof localization file ----------------------------\n", - "# Table with localization info as dataframe output\n", - "LocData = pd.DataFrame()\n", - "LocData[\"frame\"] = frames\n", - "LocData[\"x [nm]\"] = all_xloc\n", - "LocData[\"y [nm]\"] = all_yloc\n", - "LocData[\"Photon #\"] = all_photons\n", - "LocData[\"Sigma [nm]\"] = all_sigmas\n", - "LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n", - "\n", - "\n", - "# ---------------------------- Estimation of SNR ----------------------------\n", - "n_frames_for_SNR = 100\n", - "M_SNR = 10\n", - "x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n", - "y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n", - "photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)\n", - "sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)\n", - "\n", - "SNR = np.zeros(n_frames_for_SNR)\n", - "for i in range(n_frames_for_SNR):\n", - " SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([x_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)\n", - " Signal_photon = np.max(SingleEmitterImage)\n", - " Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)\n", - " SNR[i] = Signal_photon/Noise_photon\n", - "\n", - "print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))\n", - "# ---------------------------- ----------------------------\n", - "\n", - "\n", - "# Table with info\n", - "simParameters = pd.DataFrame()\n", - "simParameters[\"FOV size (nm)\"] = [FOV_size]\n", - "simParameters[\"Pixel size (nm)\"] = [pixel_size]\n", - "simParameters[\"ADC/photon\"] = [ADC_per_photon_conversion]\n", - "simParameters[\"Read-out noise (ADC)\"] = [ReadOutNoise_ADC]\n", - "simParameters[\"Constant offset (ADC)\"] = [ADC_offset]\n", - "\n", - "simParameters[\"Emitter density (emitters/um^2)\"] = [emitter_density]\n", - "simParameters[\"STD of emitter density (emitters/um^2)\"] = [emitter_density_std]\n", - "simParameters[\"Number of frames\"] = [number_of_frames]\n", - "# simParameters[\"NA\"] = [NA]\n", - "# simParameters[\"Wavelength (nm)\"] = [wavelength]\n", - "# simParameters[\"STD of wavelength (nm)\"] = [wavelength_std]\n", - "simParameters[\"Sigma (nm))\"] = [sigma]\n", - "simParameters[\"STD of Sigma (nm))\"] = [sigma_std]\n", - "simParameters[\"Number of photons\"] = [n_photons]\n", - "simParameters[\"STD of number of photons\"] = [n_photons_std]\n", - "simParameters[\"SNR\"] = [np.mean(SNR)]\n", - "simParameters[\"STD of SNR\"] = [np.std(SNR)]\n", - "\n", - "\n", - "# 
---------------------------- Finish simulation ----------------------------\n", - "# Calculating the noisy image\n", - "Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n", - "Images[Images <= 0] = 0\n", - "\n", - "# Convert to 16-bit or 32-bits integers\n", - "if Images.max() < (2**16-1):\n", - " Images = Images.astype(np.uint16)\n", - "else:\n", - " Images = Images.astype(np.uint32)\n", - "\n", - "\n", - "# ---------------------------- Display ----------------------------\n", - "# Displaying the time elapsed for simulation\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n", - "\n", - "\n", - "# Interactively display the results using Widgets\n", - "def scroll_in_time(frame):\n", - " f = plt.figure(figsize=(18,6))\n", - " plt.subplot(1,3,1)\n", - " plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n", - " plt.title('Localization image')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,3,2)\n", - " plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')\n", - " plt.title('Noise-free simulation')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,3,3)\n", - " plt.imshow(Images[frame-1], interpolation='nearest', cmap='gray')\n", - " plt.title('Noisy simulation')\n", - " plt.axis('off');\n", - "\n", - "interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n", - "\n", - "# Display the head of the dataframe with localizations\n", - "LocData.tail()\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Pz7RfSuoeJeq" - }, - "outputs": [], - "source": [ - "# @markdown ---\n", - "# @markdown #Play this cell to save the simulated stack\n", - "# @markdown ####Please select a path to the folder where to save the simulated data. It is not necessary to save the data to run the training, but keeping the simulated for your own record can be useful to check its validity.\n", - "Save_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if not os.path.exists(Save_path):\n", - " os.makedirs(Save_path)\n", - " print('Folder created.')\n", - "else:\n", - " print('Training data already exists in folder: Data overwritten.')\n", - "\n", - "saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)\n", - "# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)\n", - "LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))\n", - "simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))\n", - "print('Training dataset saved.')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K_8e3kE-JhVY" - }, - "source": [ - "## **3.2. Generate training patches**\n", - "---\n", - "\n", - "Training patches need to be created from the training data generated above. \n", - "* The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. **DEFAULT: 26**\n", - "* The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). This is used to generate the super-resolved patches as target dataset. 
Using an `upsampling_factor` of 16 will require the use of more memory and it may be necessary to decrease the `patch_size` to 16 for example. **DEFAULT: 8**\n", - "* The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**\n", - "* The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**\n", - "* The `max_num_patches` defines the maximum number of patches to generate. Fewer may be generated depending on how many patches are rejected and how many frames are available. **DEFAULT: 10000**\n", - "* The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**\n", - "* The `L2_weighting_factor` is a normalization factor used in the loss function. It helps balance the loss from the L2 norm. When using higher densities, this factor should be decreased and vice-versa. This factor can be automatically calculated using an empirical formula. **DEFAULT: 100**\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "AsNx5KzcFNvC" - }, - "outputs": [], - "source": [ - "#@markdown ## **Provide patch parameters**\n", - "\n", - "\n", - "# -------------------- User input --------------------\n", - "patch_size = 26 #@param {type:\"integer\"}\n", - "upsampling_factor = 8 #@param [\"4\", \"8\", \"16\"] {type:\"raw\"}\n", - "num_patches_per_frame = 500#@param {type:\"integer\"}\n", - "min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n", - "max_num_patches = 10000#@param {type:\"integer\"}\n", - "gaussian_sigma = 1#@param {type:\"integer\"}\n", - "\n", - "#@markdown Estimate the optimal normalization factor automatically?\n", - "Automatic_normalization = True #@param {type:\"boolean\"}\n", - "#@markdown Otherwise, it will use the following value:\n", - "L2_weighting_factor = 100 #@param {type:\"number\"}\n", - "\n", - "\n", - "# -------------------- Prepare variables --------------------\n", - "# Start the clock to measure how long it takes\n", - "start = time.time()\n", - "\n", - "# Initialize some parameters\n", - "pixel_size_hr = pixel_size/upsampling_factor # in nm\n", - "n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n", - "patch_size = patch_size*upsampling_factor\n", - "\n", - "# Dimensions of the high-res grid\n", - "Mhr = upsampling_factor*M # in pixels\n", - "Nhr = upsampling_factor*N # in pixels\n", - "\n", - "# Initialize the training patches and labels\n", - "patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n", - "spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n", - "heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n", - "\n", - "# Run over all frames and construct the training examples\n", - "k = 1 # current patch count\n", - "skip_counter = 0 # number of datasets skipped due to low density\n", - "id_start = 0 # id position in LocData for current frame\n", - "print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n", - "\n", - "n_locs = len(LocData.index)\n", - "print('Total number of localizations: '+str(n_locs))\n", - "density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)\n", - "print('Density: '+str(round(density,2))+' locs/um^2')\n", -
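The training targets described in the bullet points above (a spike image smoothed into a heat map and scaled by `L2_weighting_factor`) can be reproduced in isolation. A rough sketch with toy coordinates, assuming the same `scipy.ndimage.gaussian_filter` approach as the cell that follows:

```python
import numpy as np
from scipy.ndimage import gaussian_filter

# Toy high-resolution grid and emitter positions (illustrative values only)
Mhr, Nhr = 208, 208                    # e.g. a 26 px patch at upsampling_factor 8
rows = np.array([20, 100, 150])
cols = np.array([30, 90, 180])

spike_image = np.zeros((Mhr, Nhr), dtype=np.float32)
spike_image[rows, cols] = 1.0          # one spike per localization

gaussian_sigma = 1                     # in magnified pixels
L2_weighting_factor = 100
heatmap = L2_weighting_factor * gaussian_filter(spike_image, gaussian_sigma)
```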
"n_locs_per_patch = patch_size**2*density\n", - "\n", - "if Automatic_normalization:\n", - " # This empirical formulae attempts to balance the loss L2 function between the background and the bright spikes\n", - " # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6^2 0.1um pixel size and density of 3 (hence the 20.28), at upsampling_factor = 8\n", - " L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(upsampling_factor**2*20.28))\n", - " print('Normalization factor: '+str(round(L2_weighting_factor,2)))\n", - "\n", - "# -------------------- Patch generation loop --------------------\n", - "\n", - "print('-----------------------------------------------------------')\n", - "for (f, thisFrame) in enumerate(tqdm(Images)):\n", - "\n", - " # Upsample the frame\n", - " upsampledFrame = np.kron(thisFrame, np.ones((upsampling_factor,upsampling_factor)))\n", - " # Read all the provided high-resolution locations for current frame\n", - " DataFrame = LocData[LocData['frame'] == f+1].copy()\n", - "\n", - " # Get the approximated locations according to the high-res grid pixel size\n", - " Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n", - " Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n", - " id_start += len(DataFrame.index)\n", - "\n", - " # Build Localization image\n", - " LocImage = np.zeros((Mhr,Nhr))\n", - " LocImage[(Rhr_emitters, Chr_emitters)] = 1\n", - "\n", - " # Here, there's a choice between the original Gaussian (classification approach) and using the erf function\n", - " HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, gaussian_sigma) \n", - " # HeatMapImage = L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])), \n", - " # np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)), \n", - " # Mhr, pixel_size_hr)\n", - " \n", - "\n", - " # Generate random position for the top left corner of the patch\n", - " xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n", - " yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n", - "\n", - " for c in range(len(xc)):\n", - " if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n", - " skip_counter += 1\n", - " continue\n", - " \n", - " else:\n", - " # Limit maximal number of training examples to 15k\n", - " if k > max_num_patches:\n", - " break\n", - " else:\n", - " # Assign the patches to the right part of the images\n", - " patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n", - " spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n", - " heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n", - " k += 1 # increment current patch count\n", - "\n", - "# Remove the empty data\n", - "patches = patches[:k-1]\n", - "spikes = spikes[:k-1]\n", - "heatmaps = heatmaps[:k-1]\n", - "n_patches = k-1\n", - "\n", - "# -------------------- Failsafe --------------------\n", - "# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n", - "if ((k-1) < 5000):\n", - " # W = '\\033[0m' # white (normal)\n", - " # R = '\\033[31m' # red\n", - " print(bcolors.WARNING+'!! 
WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. !!'+bcolors.NORMAL)\n", - "\n", - "\n", - "\n", - "# -------------------- Displays --------------------\n", - "print('Number of patches skipped due to low density: '+str(skip_counter))\n", - "# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n", - "# print('Size of patches: '+str(dataSize)+' MB')\n", - "print(str(n_patches)+' patches were generated.')\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "# Display patches interactively with a slider\n", - "def scroll_patches(patch):\n", - " f = plt.figure(figsize=(16,6))\n", - " plt.subplot(1,3,1)\n", - " plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')\n", - " plt.title('Raw data (frame #'+str(patch)+')')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,3,2)\n", - " plt.imshow(heatmaps[patch-1], interpolation='nearest')\n", - " plt.title('Heat map')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,3,3)\n", - " plt.imshow(spikes[patch-1], interpolation='nearest')\n", - " plt.title('Localization map')\n", - " plt.axis('off');\n", - " \n", - " plt.savefig('/content/TrainingDataExample_DeepSTORM2D.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "\n", - "interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DSjXFMevK7Iz" - }, - "source": [ - "# **4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hVeyKU0MdAPx" - }, - "source": [ - "## **4.1. Select your paths and parameters**\n", - "\n", - "---\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n", - "\n", - "\n", - "**Training parameters**\n", - "\n", - "**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see 5). **Default value: 80**\n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n", - "\n", - "**`number_of_steps`:** Define the number of training steps by epoch. **If this value is set to 0**, by default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n", - "\n", - "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30** \n", - "\n", - "**`initial_learning_rate`:** This parameter represents the initial value to be used as learning rate in the optimizer. 
**Default value: 0.001**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "oa5cDZ7f_PF6" - }, - "outputs": [], - "source": [ - "#@markdown ###Path to training images and parameters\n", - "\n", - "model_path = \"\" #@param {type: \"string\"} \n", - "model_name = \"\" #@param {type: \"string\"} \n", - "number_of_epochs = 80#@param {type:\"integer\"}\n", - "batch_size = 16#@param {type:\"integer\"}\n", - "\n", - "number_of_steps = 0#@param {type:\"integer\"}\n", - "percentage_validation = 30 #@param {type:\"number\"}\n", - "initial_learning_rate = 0.001 #@param {type:\"number\"}\n", - "\n", - "\n", - "percentage_validation /= 100\n", - "if number_of_steps == 0: \n", - " number_of_steps = int((1-percentage_validation)*n_patches/batch_size)\n", - " print('Number of steps: '+str(number_of_steps))\n", - "\n", - "# Pretrained model path initialised here so next cell does not need to be run\n", - "h5_file_path = ''\n", - "Use_pretrained_model = False\n", - "\n", - "if not ('patches' in locals()):\n", - " # W = '\\033[0m' # white (normal)\n", - " # R = '\\033[31m' # red\n", - " print(bcolors.WARNING+'!! WARNING: No patches were found in memory currently. !!'+bcolors.NORMAL)\n", - "\n", - "Save_path = os.path.join(model_path, model_name)\n", - "if os.path.exists(Save_path):\n", - " print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n", - "\n", - "print('-----------------------------')\n", - "print('Training parameters set.')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WIyEvQBWLp9n" - }, - "source": [ - "\n", - "## **4.2. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used.
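Mechanically, "using the weights as a starting point" means loading the saved HDF5 weights into the freshly built network before training, and optionally re-reading the learning rate logged by the previous run. A hedged sketch of that idea (the `build_model` call is hypothetical, standing in for however the network is constructed here):

```python
import os
import pandas as pd

pretrained_model_path = "/content/my_previous_model"   # example path only
h5_file_path = os.path.join(pretrained_model_path, "weights_best.hdf5")

# model = build_model(...)            # hypothetical constructor for the network
# model.load_weights(h5_file_path)    # Keras: restore the saved weights

# Recover the learning rate logged by a previous ZeroCostDL4Mic run, if present
csv_path = os.path.join(pretrained_model_path, "Quality Control", "training_evaluation.csv")
if os.path.exists(csv_path):
    log = pd.read_csv(csv_path)
    if "learning rate" in log.columns:
        last_lr = log["learning rate"].iloc[-1]
        best_lr = log.loc[log["val_loss"].idxmin(), "learning rate"]
        print(last_lr, best_lr)
```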
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "oHL5g0w8LqR0" - }, - "outputs": [], - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n", - "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n", - "\n", - "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - "# --------------------- Load the model from the choosen path ------------------------\n", - " if pretrained_model_choice == \"Model_from_file\":\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n", - "\n", - "# --------------------- Download the a model provided in the XXX ------------------------\n", - "\n", - " if pretrained_model_choice == \"Model_name\":\n", - " pretrained_model_name = \"Model_name\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path) \n", - " wget.download(\"\", pretrained_model_path)\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n", - "\n", - "# --------------------- Add additional pre-trained models here ------------------------\n", - "\n", - "\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n", - " if not os.path.exists(h5_file_path):\n", - " print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)\n", - " Use_pretrained_model = False\n", - "\n", - " \n", - "# If the model path contains a pretrain model, we load the training rate, \n", - " if os.path.exists(h5_file_path):\n", - "#Here we check if the learning rate can be loaded from the quality control folder\n", - " if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"pretrained network learning rate found\")\n", - " #find the last learning rate\n", - " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n", - " #Find the learning rate corresponding to the lowest validation loss\n", - " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n", - " #print(min_val_loss)\n", - " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n", - " if Weights_choice == \"last\":\n", - " print('Last learning rate: '+str(lastLearningRate))\n", - " if Weights_choice == 
\"best\":\n", - " print('Learning rate of best validation loss: '+str(bestLearningRate))\n", - " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)\n", - "\n", - "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n", - " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - "\n", - "\n", - "# Display info about the pretrained model to be loaded (or not)\n", - "if Use_pretrained_model:\n", - " print('Weights found in:')\n", - " print(h5_file_path)\n", - " print('will be loaded prior to training.')\n", - "\n", - "else:\n", - " print('No pretrained network will be used.')\n", - " h5_file_path = ''\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OADNcie-LHxA" - }, - "source": [ - "## **4.4. Start Training**\n", - "---\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "qDgMu_mAK8US" - }, - "outputs": [], - "source": [ - "#@markdown ##Start training\n", - "\n", - "# Start the clock to measure how long it takes\n", - "start = time.time()\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "#Here we ensure that the learning rate set correctly when using pre-trained models\n", - "if Use_pretrained_model:\n", - " if Weights_choice == \"last\":\n", - " initial_learning_rate = lastLearningRate\n", - "\n", - " if Weights_choice == \"best\": \n", - " initial_learning_rate = bestLearningRate\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "\n", - "#here we check that no model with the same name already exist, if so delete\n", - "if os.path.exists(Save_path):\n", - " shutil.rmtree(Save_path)\n", - "\n", - "# Create the model folder!\n", - "os.makedirs(Save_path)\n", - "\n", - "# Export pdf summary \n", - "pdf_export(raw_data = load_raw_data, pretrained_model = Use_pretrained_model)\n", - "\n", - "# Let's go !\n", - "train_model(patches, heatmaps, Save_path, \n", - " steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,\n", - " upsampling_factor = upsampling_factor,\n", - " validation_split = percentage_validation,\n", - " initial_learning_rate = initial_learning_rate, \n", - " pretrained_model_path = h5_file_path,\n", - " L2_weighting_factor = L2_weighting_factor)\n", - "\n", - "# # Show info about the GPU memory useage\n", - "# !nvidia-smi\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "# export pdf after training to update the existing document\n", - "pdf_export(trained = True, raw_data = load_raw_data, pretrained_model = Use_pretrained_model)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4N7-ShZpLhwr" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "JDRsm7uKoBa-" - }, - "outputs": [], - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the name of this folder as `QC_model_path` . \n", - "\n", - "QC_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_path = os.path.join(model_path, model_name)\n", - "\n", - "if os.path.exists(QC_model_path):\n", - " print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n", - "else:\n", - " print(bcolors.WARNING+'!! 
WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n", - " print('Please make sure you provide a valid model path before proceeding further.')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Gw7KaHZUoHC4" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "qUc-JMOcoGNZ" - }, - "outputs": [], - "source": [ - "#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n", - "import csv\n", - "from matplotlib import pyplot as plt\n", - "\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "\n", - "with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[0]))\n", - " vallossDataFromCSV.append(float(row[1]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. 
epoch number (log scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'), bbox_inches='tight', pad_inches=0)\n", - "plt.show()\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "32eNQjFioQkY" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "\n", - "This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\" using teh corresponding localization data contained in \"QC_loc_folder\" !\n", - "\n", - "**1. The SSIM (structural similarity) map** \n", - "\n", - "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n", - "\n", - "**mSSIM** is the SSIM value calculated across the entire window of both images.\n", - "\n", - "**The output below shows the SSIM maps with the mSSIM**\n", - "\n", - "**2. The RSE (Root Squared Error) map** \n", - "\n", - "This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n", - "\n", - "\n", - "**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n", - "\n", - "**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
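For reference, these metrics follow their standard scikit-image definitions; a minimal sketch on two made-up arrays, mirroring the calls used in the QC cell below (the RSE/NRMSE lines reproduce this notebook's own convention):

```python
import numpy as np
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

rng = np.random.default_rng(0)
gt = rng.random((128, 128)).astype(np.float32)    # stand-in ground truth
pred = np.clip(gt + 0.05 * rng.standard_normal(gt.shape), 0, 1).astype(np.float32)

mssim, ssim_map = structural_similarity(gt, pred, data_range=1.0, full=True)
rse_map = np.sqrt((gt - pred) ** 2)               # root squared error map
nrmse = np.sqrt(np.mean(rse_map))                 # as computed in this notebook
psnr = peak_signal_noise_ratio(gt, pred, data_range=1.0)
print(round(mssim, 3), round(nrmse, 3), round(psnr, 1))
```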
The higher the score the better the agreement.\n", - "\n", - "**The output below shows the RSE maps with the NRMSE and PSNR values.**\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "dhlTnxC5lUZy" - }, - "outputs": [], - "source": [ - "\n", - "# ------------------------ User input ------------------------\n", - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "QC_image_folder = \"\" #@param{type:\"string\"}\n", - "QC_loc_folder = \"\" #@param{type:\"string\"}\n", - "#@markdown Get pixel size from file?\n", - "get_pixel_size_from_file = True #@param {type:\"boolean\"}\n", - "#@markdown Otherwise, use this value:\n", - "pixel_size = 100 #@param {type:\"number\"}\n", - "\n", - "if get_pixel_size_from_file:\n", - " pixel_size_INPUT = None\n", - "else:\n", - " pixel_size_INPUT = pixel_size\n", - "\n", - "\n", - "# ------------------------ QC analysis loop over provided dataset ------------------------\n", - "\n", - "savePath = os.path.join(QC_model_path, 'Quality Control')\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - "with open(os.path.join(savePath, os.path.basename(QC_model_path)+\"_QC_metrics.csv\"), \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. GT PSNR\"])\n", - "\n", - " # These lists will be used to collect all the metrics values per slice\n", - " file_name_list = []\n", - " slice_number_list = []\n", - " mSSIM_GvP_list = []\n", - " mSSIM_GvWF_list = []\n", - " NRMSE_GvP_list = []\n", - " NRMSE_GvWF_list = []\n", - " PSNR_GvP_list = []\n", - " PSNR_GvWF_list = []\n", - "\n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - " for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):\n", - " print('--------------')\n", - " print(imageFilename)\n", - " print(locFilename)\n", - "\n", - " # Get the prediction\n", - " batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)\n", - "\n", - " # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n", - " thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n", - " thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n", - "\n", - " Mhr = thisPrediction.shape[0]\n", - " Nhr = thisPrediction.shape[1]\n", - "\n", - " if pixel_size_INPUT == None:\n", - " pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))\n", - "\n", - " upsampling_factor = int(Mhr/M)\n", - " print('Upsampling factor: '+str(upsampling_factor))\n", - " pixel_size_hr = pixel_size/upsampling_factor # in nm\n", - "\n", - " # Load the localization file and display the first\n", - " LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n", - "\n", - " x = np.array(list(LocData['x [nm]']))\n", - " y = np.array(list(LocData['y [nm]']))\n", - " locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)\n", - "\n", - " # Remove extension from filename\n", - " imageFilename_no_extension = os.path.splitext(imageFilename)[0]\n", - "\n", - " # io.imsave(os.path.join(savePath, 
'GT_image_'+imageFilename), locImage)\n", - " saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n", - " test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, normalize_gt=True)\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n", - " test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n", - "\n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n", - " index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n", - "\n", - "\n", - " # Save ssim_maps\n", - " img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n", - " # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n", - " saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)\n", - "\n", - "\n", - " img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n", - " # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n", - " saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)\n", - "\n", - " \n", - " # Calculate the Root Squared Error (RSE) maps\n", - " img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", - " img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))\n", - "\n", - " # Save SE maps\n", - " img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n", - " # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n", - " saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)\n", - "\n", - " img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n", - " # io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n", - " saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)\n", - "\n", - "\n", - " # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n", - "\n", - " # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n", - " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", - " NRMSE_GTvsWF = np.sqrt(np.mean(img_RSE_GTvsWF))\n", - " \n", - " # We can also measure the peak signal to noise ratio between the images\n", - " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", - " PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n", - "\n", - " writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n", - "\n", - " # Collect values to display in dataframe output\n", - " file_name_list.append(imageFilename)\n", - " mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n", - " mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n", - " NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n", - " NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n", - " 
PSNR_GvP_list.append(PSNR_GTvsPrediction)\n", - " PSNR_GvWF_list.append(PSNR_GTvsWF)\n", - "\n", - "\n", - "# Table with metrics as dataframe output\n", - "pdResults = pd.DataFrame(index = file_name_list)\n", - "pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n", - "pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n", - "pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n", - "pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n", - "pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list\n", - "pdResults[\"Wide-field v. GT PSNR\"] = PSNR_GvWF_list\n", - "\n", - "\n", - "# ------------------------ Display ------------------------\n", - "\n", - "print('--------------------------------------------')\n", - "@interact\n", - "def show_QC_results(file = list_files(QC_image_folder, 'tif')):\n", - "\n", - " plt.figure(figsize=(15,15))\n", - " # Target (Ground-truth)\n", - " plt.subplot(3,3,1)\n", - " plt.axis('off')\n", - " img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n", - " plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))\n", - " plt.title('Target',fontsize=15)\n", - "\n", - " # Wide-field\n", - " plt.subplot(3,3,2)\n", - " plt.axis('off')\n", - " img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n", - " plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n", - " plt.title('Widefield',fontsize=15)\n", - "\n", - " #Prediction\n", - " plt.subplot(3,3,3)\n", - " plt.axis('off')\n", - " img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n", - " plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n", - " plt.title('Prediction',fontsize=15)\n", - "\n", - " #Setting up colours\n", - " cmap = plt.cm.CMRmap\n", - "\n", - " #SSIM between GT and Source\n", - " plt.subplot(3,3,5)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - " img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n", - " imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. Widefield',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. GT mSSIM\"],3)),fontsize=14)\n", - " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - " #SSIM between GT and Prediction\n", - " plt.subplot(3,3,6)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - " img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n", - " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - " plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. 
Prediction',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. GT mSSIM\"],3)),fontsize=14)\n", - "\n", - " #Root Squared Error between GT and Source\n", - " plt.subplot(3,3,8)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - " img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n", - " imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n", - " plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Widefield',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. GT PSNR\"],3)),fontsize=14)\n", - " plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - " #Root Squared Error between GT and Prediction\n", - " plt.subplot(3,3,9)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - " img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n", - " imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. GT PSNR\"],3)),fontsize=14)\n", - " plt.savefig(QC_model_path+'/Quality Control/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n", - "print('--------------------------------------------')\n", - "pdResults.head()\n", - "\n", - "# Export pdf wth summary of QC results\n", - "qc_pdf_export()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yTRou0izLjhd" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eAf8aBDmWTx7" - }, - "source": [ - "## **6.1 Generate image prediction and localizations from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
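The `threshold` and `neighborhood_size` parameters described below control a local-maxima search on the network output. A standalone sketch of that idea (this is not the notebook's own `batchFramePredictionLocalization`, and it assumes the prediction has been scaled to [0,1]):

```python
import numpy as np
from scipy.ndimage import maximum_filter

rng = np.random.default_rng(0)
prediction = rng.random((256, 256)).astype(np.float32)  # stand-in network output in [0,1]

threshold = 0.1            # confidence below this is ignored
neighborhood_size = 3      # must be the maximum within a 3x3 window (recovery pixels)

local_max = prediction == maximum_filter(prediction, size=neighborhood_size)
candidates = local_max & (prediction > threshold)
rows, cols = np.nonzero(candidates)
print(len(rows), "candidate localizations")
```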
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n", - "\n", - "**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n", - "\n", - "**`Result_folder`:** This folder will contain the found localizations as CSV files.\n", - "\n", - "**`batch_size`:** This parameter determines how many frames are processed by any single pass on the GPU. A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 4**\n", - "\n", - "**`threshold`:** This parameter determines the threshold for local maxima finding. The value is expected to reside in the range **[0,1]**. A higher `threshold` will result in fewer localizations. **DEFAULT: 0.1**\n", - "\n", - "**`neighborhood_size`:** This parameter determines the size of the neighborhood within which the prediction needs to be a local maximum, in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**\n", - "\n", - "**`use_local_average`:** This parameter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True** it will make inference slightly slower depending on the size of the FOV. **DEFAULT: True**\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "7qn06T_A0lxf" - }, - "outputs": [], - "source": [ - "\n", - "# ------------------------------- User input -------------------------------\n", - "#@markdown ### Data parameters\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Result_folder = \"\" #@param {type:\"string\"}\n", - "#@markdown Get pixel size from file?\n", - "get_pixel_size_from_file = True #@param {type:\"boolean\"}\n", - "#@markdown Otherwise, use this value (in nm):\n", - "pixel_size = 100 #@param {type:\"number\"}\n", - "\n", - "#@markdown ### Model parameters\n", - "#@markdown Do you want to use the model you just trained?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "#@markdown Otherwise, please provide path to the model folder below\n", - "prediction_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ### Prediction parameters\n", - "batch_size = 4#@param {type:\"integer\"}\n", - "\n", - "#@markdown ### Post processing parameters\n", - "threshold = 0.1#@param {type:\"number\"}\n", - "neighborhood_size = 3#@param {type:\"integer\"}\n", - "#@markdown Do you want to locally average the model output with CoG estimator ?\n", - "use_local_average = True #@param {type:\"boolean\"}\n", - "\n", - "\n", - "if get_pixel_size_from_file:\n", - " pixel_size = None\n", - "\n", - "if (Use_the_current_trained_model): \n", - " prediction_model_path = os.path.join(model_path, model_name)\n", - "\n", - "if os.path.exists(prediction_model_path):\n", - " print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n", - "else:\n", - " print(bcolors.WARNING+'!!
WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n", - " print('Please make sure you provide a valid model path before proceeding further.')\n", - "\n", - "# inform user whether local averaging is being used\n", - "if use_local_average == True: \n", - " print('Using local averaging')\n", - "\n", - "if not os.path.exists(Result_folder):\n", - " print('Result folder was created.')\n", - " os.makedirs(Result_folder)\n", - "\n", - "\n", - "# ------------------------------- Run predictions -------------------------------\n", - "\n", - "start = time.time()\n", - "#%% This script tests the trained fully convolutional network based on the \n", - "# saved training weights, and normalization created using train_model.\n", - "\n", - "if os.path.isdir(Data_folder): \n", - " for filename in list_files(Data_folder, 'tif'):\n", - " # run the testing/reconstruction process\n", - " print(\"------------------------------------\")\n", - " print(\"Running prediction on: \"+ filename)\n", - " batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder, \n", - " batch_size, \n", - " threshold, \n", - " neighborhood_size, \n", - " use_local_average,\n", - " pixel_size = pixel_size)\n", - "\n", - "elif os.path.isfile(Data_folder):\n", - " batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder, \n", - " batch_size, \n", - " threshold, \n", - " neighborhood_size, \n", - " use_local_average, \n", - " pixel_size = pixel_size)\n", - "\n", - "\n", - "\n", - "print('--------------------------------------------------------------------')\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "\n", - "# ------------------------------- Interactive display -------------------------------\n", - "\n", - "print('--------------------------------------------------------------------')\n", - "print('---------------------------- Previews ------------------------------')\n", - "print('--------------------------------------------------------------------')\n", - "\n", - "if os.path.isdir(Data_folder): \n", - " @interact\n", - " def show_QC_results(file = list_files(Data_folder, 'tif')):\n", - "\n", - " plt.figure(figsize=(15,7.5))\n", - " # Wide-field\n", - " plt.subplot(1,2,1)\n", - " plt.axis('off')\n", - " img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n", - " plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n", - " plt.title('Widefield', fontsize=15)\n", - " # Prediction\n", - " plt.subplot(1,2,2)\n", - " plt.axis('off')\n", - " img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n", - " plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n", - " plt.title('Predicted',fontsize=15)\n", - "\n", - "if os.path.isfile(Data_folder):\n", - "\n", - " plt.figure(figsize=(15,7.5))\n", - " # Wide-field\n", - " plt.subplot(1,2,1)\n", - " plt.axis('off')\n", - " img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))\n", - " plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n", - " plt.title('Widefield', fontsize=15)\n", - " # Prediction\n", - " plt.subplot(1,2,2)\n", - " plt.axis('off')\n", - " img_Prediction = io.imread(os.path.join(Result_folder, 
'Predicted_'+os.path.basename(Data_folder)))\n", - " plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n", - " plt.title('Predicted',fontsize=15)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZekzexaPmzFZ" - }, - "source": [ - "## **6.2 Drift correction**\n", - "---\n", - "\n", - "The visualization above is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. The display is a preview without any drift correction applied. This section performs drift correction using cross-correlation between time bins to estimate the drift.\n", - "\n", - "**`Loc_file_path`:** is the path to the localization file to use for visualization.\n", - "\n", - "**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n", - "\n", - "**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the image reconstructions used for the Drift Correction estimation (in **nm**). A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**\n", - "\n", - "**`number_of_bins`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bin are used to build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**\n", - "\n", - "**`polynomial_fit_degree`:** The drift obtained for each temporal bin needs to be interpolated to every single frame. This is performed by a polynomial fit, the degree of which is defined here. **DEFAULT: 4**\n", - "\n", - " The drift-corrected localization data is automatically saved in the `save_path` folder." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "hYtP_vh6mzUP" - }, - "outputs": [], - "source": [ - "# @markdown ##Data parameters\n", - "Loc_file_path = \"\" #@param {type:\"string\"}\n", - "# @markdown Provide information about original data.
Get the info automatically from the raw data?\n", - "Get_info_from_file = True #@param {type:\"boolean\"}\n", - "# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n", - "original_image_path = \"\" #@param {type:\"string\"}\n", - "# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n", - "image_width = 256#@param {type:\"integer\"}\n", - "image_height = 256#@param {type:\"integer\"}\n", - "pixel_size = 100 #@param {type:\"number\"}\n", - "\n", - "# @markdown ##Drift correction parameters\n", - "visualization_pixel_size = 20#@param {type:\"number\"}\n", - "number_of_bins = 50#@param {type:\"integer\"}\n", - "polynomial_fit_degree = 4#@param {type:\"integer\"}\n", - "\n", - "# @markdown ##Saving parameters\n", - "save_path = '' #@param {type:\"string\"}\n", - "\n", - "\n", - "# Let's go !\n", - "start = time.time()\n", - "\n", - "# Get info from the raw file if selected\n", - "if Get_info_from_file:\n", - " pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n", - "\n", - "# Read the localizations in\n", - "LocData = pd.read_csv(Loc_file_path)\n", - "\n", - "# Calculate a few variables \n", - "Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n", - "Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n", - "nFrames = max(LocData['frame'])\n", - "x_max = max(LocData['x [nm]'])\n", - "y_max = max(LocData['y [nm]'])\n", - "image_size = (Mhr, Nhr)\n", - "n_locs = len(LocData.index)\n", - "\n", - "print('Image size: '+str(image_size))\n", - "print('Number of frames in data: '+str(nFrames))\n", - "print('Number of localizations in data: '+str(n_locs))\n", - "\n", - "blocksize = math.ceil(nFrames/number_of_bins)\n", - "print('Number of frames per block: '+str(blocksize))\n", - "\n", - "blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()\n", - "xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n", - "yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n", - "\n", - "# Preparing the Reference image\n", - "photon_array = np.ones(yc_array.shape[0])\n", - "sigma_array = np.ones(yc_array.shape[0])\n", - "ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n", - "ImagesRef = np.rot90(ImageRef, k=2)\n", - "\n", - "xDrift = np.zeros(number_of_bins)\n", - "yDrift = np.zeros(number_of_bins)\n", - "\n", - "filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n", - "\n", - "with open(os.path.join(save_path, filename_no_extension+\"_DriftCorrectionData.csv\"), \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"Block #\", \"x-drift [nm]\",\"y-drift [nm]\"])\n", - "\n", - " for b in tqdm(range(number_of_bins)):\n", - "\n", - " blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()\n", - " xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n", - " yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n", - "\n", - " photon_array = np.ones(yc_array.shape[0])\n", - " sigma_array = np.ones(yc_array.shape[0])\n", - " ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = 
visualization_pixel_size)\n", - "\n", - " XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')\n", - " yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')\n", - "\n", - " # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, visualization_pixel_size)\n", - " # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)\n", - " writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])\n", - "\n", - "\n", - "print('--------------------------------------------------------------------')\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "print('Fitting drift data...')\n", - "bin_number = np.arange(number_of_bins)*blocksize + blocksize/2\n", - "xDrift = (xDrift-xDrift[0])*visualization_pixel_size\n", - "yDrift = (yDrift-yDrift[0])*visualization_pixel_size\n", - "\n", - "xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)\n", - "yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)\n", - "\n", - "xDriftFit = np.poly1d(xDriftCoeff)\n", - "yDriftFit = np.poly1d(yDriftCoeff)\n", - "bins = np.arange(nFrames)\n", - "xDriftInterpolated = xDriftFit(bins)\n", - "yDriftInterpolated = yDriftFit(bins)\n", - "\n", - "\n", - "# ------------------ Displaying the image results ------------------\n", - "\n", - "plt.figure(figsize=(15,10))\n", - "plt.plot(bin_number,xDrift, 'r+', label='x-drift')\n", - "plt.plot(bin_number,yDrift, 'b+', label='y-drift')\n", - "plt.plot(bins,xDriftInterpolated, 'r-', label='x-drift (fit)')\n", - "plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')\n", - "plt.title('Cross-correlation estimated drift')\n", - "plt.ylabel('Drift [nm]')\n", - "plt.xlabel('Bin number')\n", - "plt.legend();\n", - "\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\", hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "\n", - "# ------------------ Actual drift correction -------------------\n", - "\n", - "print('Correcting localization data...')\n", - "xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)\n", - "yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)\n", - "frames = LocData['frame'].to_numpy(dtype=np.int32)\n", - "\n", - "\n", - "xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)\n", - "ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n", - "ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)\n", - "\n", - "\n", - "# ------------------ Displaying the image results ------------------\n", - "plt.figure(figsize=(15,7.5))\n", - "# Raw\n", - "plt.subplot(1,2,1)\n", - "plt.axis('off')\n", - "plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))\n", - "plt.title('Raw', fontsize=15);\n", - "# Corrected\n", - "plt.subplot(1,2,2)\n", - "plt.axis('off')\n", - "plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))\n", - "plt.title('Corrected',fontsize=15);\n", - "\n", - "\n", - "# ------------------ Table with info -------------------\n", -
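For completeness, the correction applied by `correctDriftLocalization` is conceptually a per-frame subtraction of the interpolated drift from each localization; a toy sketch of that idea (not the actual helper, and the sign convention depends on how the drift was measured):

```python
import numpy as np

# Toy localizations: 1-based frame index and coordinates in nm
frames = np.array([1, 1, 2, 3])
x_nm = np.array([1000.0, 2000.0, 1500.0, 1800.0])
y_nm = np.array([1200.0, 800.0, 900.0, 1100.0])

# Interpolated drift per frame (nm), e.g. evaluated from the polynomial fits above
x_drift = np.array([0.0, 2.5, 5.1])
y_drift = np.array([0.0, -1.0, -2.2])

x_corr = x_nm - x_drift[frames - 1]
y_corr = y_nm - y_drift[frames - 1]
print(x_corr, y_corr)
```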
"driftCorrectedLocData = pd.DataFrame()\n", - "driftCorrectedLocData['frame'] = frames\n", - "driftCorrectedLocData['x [nm]'] = xc_array_Corr\n", - "driftCorrectedLocData['y [nm]'] = yc_array_Corr\n", - "driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']\n", - "\n", - "driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))\n", - "print('-------------------------------')\n", - "print('Corrected localizations saved.')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mzOuc-V7rB-r" - }, - "source": [ - "## **6.3 Visualization of the localizations**\n", - "---\n", - "\n", - "\n", - "The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.\n", - "\n", - "**`Loc_file_path`:** is the path to the localization file to use for visualization.\n", - "\n", - "**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n", - "\n", - "**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). **DEFAULT: 10**\n", - "\n", - "**`visualization_mode`:** This parameter defines what visualization method is used to visualize the final image. NOTES: The Integrated Gaussian can be quite slow. **DEFAULT: Simple histogram.**\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "876yIXnqq-nW" - }, - "outputs": [], - "source": [ - "# @markdown ##Data parameters\n", - "Use_current_drift_corrected_localizations = True #@param {type:\"boolean\"}\n", - "# @markdown Otherwise provide a localization file path\n", - "Loc_file_path = \"\" #@param {type:\"string\"}\n", - "# @markdown Provide information about original data. 
Get the info automatically from the raw data?\n", - "Get_info_from_file = True #@param {type:\"boolean\"}\n", - "# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n", - "original_image_path = \"\" #@param {type:\"string\"}\n", - "# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n", - "image_width = 256#@param {type:\"integer\"}\n", - "image_height = 256#@param {type:\"integer\"}\n", - "pixel_size = 100#@param {type:\"number\"}\n", - "\n", - "# @markdown ##Visualization parameters\n", - "visualization_pixel_size = 10#@param {type:\"number\"}\n", - "visualization_mode = \"Simple histogram\" #@param [\"Simple histogram\", \"Integrated Gaussian (SLOW!)\"]\n", - "\n", - "if not Use_current_drift_corrected_localizations:\n", - " filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n", - "\n", - "\n", - "if Get_info_from_file:\n", - " pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n", - "\n", - "if Use_current_drift_corrected_localizations:\n", - " LocData = driftCorrectedLocData\n", - "else:\n", - " LocData = pd.read_csv(Loc_file_path)\n", - "\n", - "Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n", - "Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n", - "\n", - "\n", - "nFrames = max(LocData['frame'])\n", - "x_max = max(LocData['x [nm]'])\n", - "y_max = max(LocData['y [nm]'])\n", - "image_size = (Mhr, Nhr)\n", - "\n", - "print('Image size: '+str(image_size))\n", - "print('Number of frames in data: '+str(nFrames))\n", - "print('Number of localizations in data: '+str(len(LocData.index)))\n", - "\n", - "xc_array = LocData['x [nm]'].to_numpy()\n", - "yc_array = LocData['y [nm]'].to_numpy()\n", - "if (visualization_mode == 'Simple histogram'):\n", - " locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n", - "elif (visualization_mode == 'Shifted histogram'):\n", - " print(bcolors.WARNING+'Method not implemented yet!'+bcolors.NORMAL)\n", - " locImage = np.zeros(image_size)\n", - "elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):\n", - " photon_array = np.ones(xc_array.shape)\n", - " sigma_array = np.ones(xc_array.shape)\n", - " locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)\n", - "\n", - "print('--------------------------------------------------------------------')\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "minutes, seconds = divmod(dt, 60) \n", - "hours, minutes = divmod(minutes, 60) \n", - "print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n", - "\n", - "# Display\n", - "plt.figure(figsize=(20,10))\n", - "plt.axis('off')\n", - "# plt.imshow(locImage, cmap='gray');\n", - "plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));\n", - "\n", - "\n", - "LocData.head()\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "PdOhWwMn1zIT" - }, - "outputs": [], - "source": [ - "# @markdown ---\n", - "# @markdown #Play this cell to save the visualization\n", - "# @markdown ####Please select a path to the folder where to save the 
visualization.\n", - "save_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if not os.path.exists(save_path):\n", - " os.makedirs(save_path)\n", - " print('Folder created.')\n", - "\n", - "saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)\n", - "print('Image saved.')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1EszIF4Dkz_n" - }, - "source": [ - "## **6.4. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UgN-NooKk3nV" - }, - "source": [ - "\n", - "#**Thank you for using Deep-STORM 2D!**" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "machine_shape": "hm", - "name": "Deep-STORM_2D_ZeroCostDL4Mic.ipynb", - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Deep-STORM_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1kD3rjN5XX5C33cQuX1DVc_n89cMqNvS_","timestamp":1610633423190},{"file_id":"1w95RljMrg15FLDRnEJiLIEa-lW-jEjQS","timestamp":1602684895691},{"file_id":"169qcwQo-yw15PwoGatXAdBvjs4wt_foD","timestamp":1592147948265},{"file_id":"1gjRCgDORKi_GNBu4QnVCBkSWrfPtqL-E","timestamp":1588525976305},{"file_id":"1DFy6aCi1XAVdjA5KLRZirB2aMZkMFdv-","timestamp":1587998755430},{"file_id":"1NpzigQoXGy3GFdxh4_jvG1PnBfyrcpBs","timestamp":1587569988032},{"file_id":"1jdI540qAfMSQwjnMhoAFkGJH9EbHwNSf","timestamp":1587486196143}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"FpCtYevLHfl4"},"source":["# **Deep-STORM (2D)**\n","\n","---\n","\n","Deep-STORM is a neural network capable of image reconstruction from high-density single-molecule localization microscopy (SMLM), first published in 2018 by [Nehme *et al.* in Optica](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458). The architecture used here is a U-Net based network without skip connections. This network allows image reconstruction of 2D super-resolution images, in a supervised training manner. The network is trained using simulated high-density SMLM data for which the ground-truth is available. These simulations are obtained from random distribution of single molecules in a field-of-view and therefore do not imprint structural priors during training. 
The network output a super-resolution image with increased pixel density (typically upsampling factor of 8 in each dimension).\n","\n","Deep-STORM has **two key advantages**:\n","- SMLM reconstruction at high density of emitters\n","- fast prediction (reconstruction) once the model is trained appropriately, compared to more common multi-emitter fitting processes.\n","\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n","**Deep-STORM: super-resolution single-molecule microscopy by deep learning**, Optica (2018) by *Elias Nehme, Lucien E. Weiss, Tomer Michaeli, and Yoav Shechtman* (https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458)\n","\n","And source code found in: /~https://github.com/EliasNehme/Deep-STORM\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"wyzTn3IcHq6Y"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. 
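A minimal sketch (not taken from the notebook cells themselves) of what the upsampling factor mentioned above means in practice: each raw camera pixel is replicated onto an 8x denser grid before prediction, which is the same nearest-neighbour np.kron upsampling used in the prediction cell later in this notebook. The frame size below is an example value only.

import numpy as np

upsampling_factor = 8
frame = np.random.rand(64, 64).astype(np.float32)  # one raw SMLM frame (example size)

# Replicate every raw pixel into an 8x8 block, giving the denser prediction grid
frame_upsampled = np.kron(frame, np.ones((upsampling_factor, upsampling_factor)))
print(frame.shape, '->', frame_upsampled.shape)     # (64, 64) -> (512, 512)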
This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"bEy4EBXHHyAX"},"source":["#**0. Before getting started**\n","---\n"," Deep-STORM is able to train on simulated dataset of SMLM data (see https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-4-458 for more info). Here, we provide a simulator that will generate training dataset (section 3.1.b). A few parameters will allow you to match the simulation to your experimental data. Similarly to what is described in the paper, simulations obtained from ThunderSTORM can also be loaded here (section 3.1.a).\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"jRnQZWSZhArJ"},"source":["# **1. Install Deep-STORM and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"kSrZMo3X_NhO","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'Deep-STORM'\n","\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install Deep-STORM and dependencies\n","# %% Model definition + helper functions\n","\n","!pip install fpdf\n","# Import keras modules and 
libraries\n","from tensorflow import keras\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.layers import Input, Activation, UpSampling2D, Convolution2D, MaxPooling2D, BatchNormalization, Layer\n","from tensorflow.keras.callbacks import Callback\n","from tensorflow.keras import backend as K\n","from tensorflow.keras import optimizers, losses\n","\n","from tensorflow.keras.preprocessing.image import ImageDataGenerator\n","from tensorflow.keras.callbacks import ModelCheckpoint\n","from tensorflow.keras.callbacks import ReduceLROnPlateau\n","from skimage.transform import warp\n","from skimage.transform import SimilarityTransform\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from scipy.signal import fftconvolve\n","\n","# Import common libraries\n","import tensorflow as tf\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import h5py\n","import scipy.io as sio\n","from os.path import abspath\n","from sklearn.model_selection import train_test_split\n","from skimage import io\n","import time\n","import os\n","import shutil\n","import csv\n","from PIL import Image \n","from PIL.TiffTags import TAGS\n","from scipy.ndimage import gaussian_filter\n","import math\n","from astropy.visualization import simple_norm\n","from sys import getsizeof\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from datetime import datetime\n","\n","# For sliders and dropdown menu, progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","from tqdm import tqdm\n","\n","# For Multi-threading in simulation\n","from numba import njit, prange\n","\n","\n","\n","# define a function that projects and rescales an image to the range [0,1]\n","def project_01(im):\n"," im = np.squeeze(im)\n"," min_val = im.min()\n"," max_val = im.max()\n"," return (im - min_val)/(max_val - min_val)\n","\n","# normalize image given mean and std\n","def normalize_im(im, dmean, dstd):\n"," im = np.squeeze(im)\n"," im_norm = np.zeros(im.shape,dtype=np.float32)\n"," im_norm = (im - dmean)/dstd\n"," return im_norm\n","\n","# Define the loss history recorder\n","class LossHistory(Callback):\n"," def on_train_begin(self, logs={}):\n"," self.losses = []\n","\n"," def on_batch_end(self, batch, logs={}):\n"," self.losses.append(logs.get('loss'))\n"," \n","# Define a matlab like gaussian 2D filter\n","def matlab_style_gauss2D(shape=(7,7),sigma=1):\n"," \"\"\" \n"," 2D gaussian filter - should give the same result as:\n"," MATLAB's fspecial('gaussian',[shape],[sigma]) \n"," \"\"\"\n"," m,n = [(ss-1.)/2. 
for ss in shape]\n"," y,x = np.ogrid[-m:m+1,-n:n+1]\n"," h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )\n"," h.astype(dtype=K.floatx())\n"," h[ h < np.finfo(h.dtype).eps*h.max() ] = 0\n"," sumh = h.sum()\n"," if sumh != 0:\n"," h /= sumh\n"," h = h*2.0\n"," h = h.astype('float32')\n"," return h\n","\n","# Expand the filter dimensions\n","psf_heatmap = matlab_style_gauss2D(shape = (7,7),sigma=1)\n","gfilter = tf.reshape(psf_heatmap, [7, 7, 1, 1])\n","\n","# Combined MSE + L1 loss\n","def L1L2loss(input_shape):\n"," def bump_mse(heatmap_true, spikes_pred):\n","\n"," # generate the heatmap corresponding to the predicted spikes\n"," heatmap_pred = K.conv2d(spikes_pred, gfilter, strides=(1, 1), padding='same')\n","\n"," # heatmaps MSE\n"," loss_heatmaps = losses.mean_squared_error(heatmap_true,heatmap_pred)\n","\n"," # l1 on the predicted spikes\n"," loss_spikes = losses.mean_absolute_error(spikes_pred,tf.zeros(input_shape))\n"," return loss_heatmaps + loss_spikes\n"," return bump_mse\n","\n","# Define the concatenated conv2, batch normalization, and relu block\n","def conv_bn_relu(nb_filter, rk, ck, name):\n"," def f(input):\n"," conv = Convolution2D(nb_filter, kernel_size=(rk, ck), strides=(1,1),\\\n"," padding=\"same\", use_bias=False,\\\n"," kernel_initializer=\"Orthogonal\",name='conv-'+name)(input)\n"," conv_norm = BatchNormalization(name='BN-'+name)(conv)\n"," conv_norm_relu = Activation(activation = \"relu\",name='Relu-'+name)(conv_norm)\n"," return conv_norm_relu\n"," return f\n","\n","# Define the model architechture\n","def CNN(input,names):\n"," Features1 = conv_bn_relu(32,3,3,names+'F1')(input)\n"," pool1 = MaxPooling2D(pool_size=(2,2),name=names+'Pool1')(Features1)\n"," Features2 = conv_bn_relu(64,3,3,names+'F2')(pool1)\n"," pool2 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool2')(Features2)\n"," Features3 = conv_bn_relu(128,3,3,names+'F3')(pool2)\n"," pool3 = MaxPooling2D(pool_size=(2, 2),name=names+'Pool3')(Features3)\n"," Features4 = conv_bn_relu(512,3,3,names+'F4')(pool3)\n"," up5 = UpSampling2D(size=(2, 2),name=names+'Upsample1')(Features4)\n"," Features5 = conv_bn_relu(128,3,3,names+'F5')(up5)\n"," up6 = UpSampling2D(size=(2, 2),name=names+'Upsample2')(Features5)\n"," Features6 = conv_bn_relu(64,3,3,names+'F6')(up6)\n"," up7 = UpSampling2D(size=(2, 2),name=names+'Upsample3')(Features6)\n"," Features7 = conv_bn_relu(32,3,3,names+'F7')(up7)\n"," return Features7\n","\n","# Define the Model building for an arbitrary input size\n","def buildModel(input_dim, initial_learning_rate = 0.001):\n"," input_ = Input (shape = (input_dim))\n"," act_ = CNN (input_,'CNN')\n"," density_pred = Convolution2D(1, kernel_size=(1, 1), strides=(1, 1), padding=\"same\",\\\n"," activation=\"linear\", use_bias = False,\\\n"," kernel_initializer=\"Orthogonal\",name='Prediction')(act_)\n"," model = Model (inputs= input_, outputs=density_pred)\n"," opt = optimizers.Adam(lr = initial_learning_rate)\n"," model.compile(optimizer=opt, loss = L1L2loss(input_dim))\n"," return model\n","\n","\n","# define a function that trains a model for a given data SNR and density\n","def train_model(patches, heatmaps, modelPath, epochs, steps_per_epoch, batch_size, upsampling_factor=8, validation_split = 0.3, initial_learning_rate = 0.001, pretrained_model_path = '', L2_weighting_factor = 100):\n"," \n"," \"\"\"\n"," This function trains a CNN model on the desired training set, given the \n"," upsampled training images and labels generated in MATLAB.\n"," \n"," # Inputs\n"," # TO UPDATE ----------\n","\n"," # Outputs\n"," 
function saves the weights of the trained model to a hdf5, and the \n"," normalization factors to a mat file. These will be loaded later for testing \n"," the model in test_model. \n"," \"\"\"\n"," \n"," # for reproducibility\n"," np.random.seed(123)\n","\n"," X_train, X_test, y_train, y_test = train_test_split(patches, heatmaps, test_size = validation_split, random_state=42)\n"," print('Number of training examples: %d' % X_train.shape[0])\n"," print('Number of validation examples: %d' % X_test.shape[0])\n"," \n"," # Setting type\n"," X_train = X_train.astype('float32')\n"," X_test = X_test.astype('float32')\n"," y_train = y_train.astype('float32')\n"," y_test = y_test.astype('float32')\n","\n"," \n"," #===================== Training set normalization ==========================\n"," # normalize training images to be in the range [0,1] and calculate the \n"," # training set mean and std\n"," mean_train = np.zeros(X_train.shape[0],dtype=np.float32)\n"," std_train = np.zeros(X_train.shape[0], dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train[i, :, :] = project_01(X_train[i, :, :])\n"," mean_train[i] = X_train[i, :, :].mean()\n"," std_train[i] = X_train[i, :, :].std()\n","\n"," # resulting normalized training images\n"," mean_val_train = mean_train.mean()\n"," std_val_train = std_train.mean()\n"," X_train_norm = np.zeros(X_train.shape, dtype=np.float32)\n"," for i in range(X_train.shape[0]):\n"," X_train_norm[i, :, :] = normalize_im(X_train[i, :, :], mean_val_train, std_val_train)\n"," \n"," # patch size\n"," psize = X_train_norm.shape[1]\n","\n"," # Reshaping\n"," X_train_norm = X_train_norm.reshape(X_train.shape[0], psize, psize, 1)\n","\n"," # ===================== Test set normalization ==========================\n"," # normalize test images to be in the range [0,1] and calculate the test set \n"," # mean and std\n"," mean_test = np.zeros(X_test.shape[0],dtype=np.float32)\n"," std_test = np.zeros(X_test.shape[0], dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test[i, :, :] = project_01(X_test[i, :, :])\n"," mean_test[i] = X_test[i, :, :].mean()\n"," std_test[i] = X_test[i, :, :].std()\n","\n"," # resulting normalized test images\n"," mean_val_test = mean_test.mean()\n"," std_val_test = std_test.mean()\n"," X_test_norm = np.zeros(X_test.shape, dtype=np.float32)\n"," for i in range(X_test.shape[0]):\n"," X_test_norm[i, :, :] = normalize_im(X_test[i, :, :], mean_val_test, std_val_test)\n"," \n"," # Reshaping\n"," X_test_norm = X_test_norm.reshape(X_test.shape[0], psize, psize, 1)\n","\n"," # Reshaping labels\n"," Y_train = y_train.reshape(y_train.shape[0], psize, psize, 1)\n"," Y_test = y_test.reshape(y_test.shape[0], psize, psize, 1)\n","\n"," # Save datasets to a matfile to open later in matlab\n"," mdict = {\"mean_test\": mean_val_test, \"std_test\": std_val_test, \"upsampling_factor\": upsampling_factor, \"Normalization factor\": L2_weighting_factor}\n"," sio.savemat(os.path.join(modelPath,\"model_metadata.mat\"), mdict)\n","\n","\n"," # Set the dimensions ordering according to tensorflow consensous\n"," # K.set_image_dim_ordering('tf')\n"," K.set_image_data_format('channels_last')\n","\n"," # Save the model weights after each epoch if the validation loss decreased\n"," checkpointer = ModelCheckpoint(filepath=os.path.join(modelPath,\"weights_best.hdf5\"), verbose=1,\n"," save_best_only=True)\n","\n"," # Change learning when loss reaches a plataeu\n"," change_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.00005)\n"," \n"," # 
Model building and complitation\n"," model = buildModel((psize, psize, 1), initial_learning_rate = initial_learning_rate)\n"," model.summary()\n","\n"," # Load pretrained model\n"," if not pretrained_model_path:\n"," print('Using random initial model weights.')\n"," else:\n"," print('Loading model weights from '+pretrained_model_path)\n"," model.load_weights(pretrained_model_path)\n"," \n"," # Create an image data generator for real time data augmentation\n"," datagen = ImageDataGenerator(\n"," featurewise_center=False, # set input mean to 0 over the dataset\n"," samplewise_center=False, # set each sample mean to 0\n"," featurewise_std_normalization=False, # divide inputs by std of the dataset\n"," samplewise_std_normalization=False, # divide each input by its std\n"," zca_whitening=False, # apply ZCA whitening\n"," rotation_range=0., # randomly rotate images in the range (degrees, 0 to 180)\n"," width_shift_range=0., # randomly shift images horizontally (fraction of total width)\n"," height_shift_range=0., # randomly shift images vertically (fraction of total height)\n"," zoom_range=0.,\n"," shear_range=0.,\n"," horizontal_flip=False, # randomly flip images\n"," vertical_flip=False, # randomly flip images\n"," fill_mode='constant',\n"," data_format=K.image_data_format())\n","\n"," # Fit the image generator on the training data\n"," datagen.fit(X_train_norm)\n"," \n"," # loss history recorder\n"," history = LossHistory()\n","\n"," # Inform user training begun\n"," print('-------------------------------')\n"," print('Training model...')\n","\n"," # Fit model on the batches generated by datagen.flow()\n"," train_history = model.fit_generator(datagen.flow(X_train_norm, Y_train, batch_size=batch_size), \n"," steps_per_epoch=steps_per_epoch, epochs=epochs, verbose=1, \n"," validation_data=(X_test_norm, Y_test), \n"," callbacks=[history, checkpointer, change_lr]) \n","\n"," # Inform user training ended\n"," print('-------------------------------')\n"," print('Training Complete!')\n"," \n"," # Save the last model\n"," model.save(os.path.join(modelPath, 'weights_last.hdf5'))\n","\n"," # convert the history.history dict to a pandas DataFrame: \n"," lossData = pd.DataFrame(train_history.history) \n","\n"," if os.path.exists(os.path.join(modelPath,\"Quality Control\")):\n"," shutil.rmtree(os.path.join(modelPath,\"Quality Control\"))\n","\n"," os.makedirs(os.path.join(modelPath,\"Quality Control\"))\n","\n"," # The training evaluation.csv is saved (overwrites the Files if needed). 
\n"," lossDataCSVpath = os.path.join(modelPath,\"Quality Control/training_evaluation.csv\")\n"," with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss','learning rate'])\n"," for i in range(len(train_history.history['loss'])):\n"," writer.writerow([train_history.history['loss'][i], train_history.history['val_loss'][i], train_history.history['lr'][i]])\n","\n"," return\n","\n","\n","# Normalization functions from Martin Weigert used in CARE\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","# Multi-threaded Erf-based image construction\n","@njit(parallel=True)\n","def FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," erfImage = np.zeros((w, h))\n"," for ij in prange(w*h):\n"," j = int(ij/w)\n"," i = ij - j*w\n"," for (xc, yc, photon, sigma) in zip(xc_array, yc_array, photon_array, sigma_array):\n"," # Don't bother if the emitter has photons <= 0 or if Sigma <= 0\n"," if (sigma > 0) and (photon > 0):\n"," S = sigma*math.sqrt(2)\n"," x = i*pixel_size - xc\n"," y = j*pixel_size - yc\n"," # Don't bother if the emitter is further than 4 sigma from the centre of the pixel\n"," if (x+pixel_size/2)**2 + (y+pixel_size/2)**2 < 16*sigma**2:\n"," ErfX = math.erf((x+pixel_size)/S) - math.erf(x/S)\n"," ErfY = math.erf((y+pixel_size)/S) - math.erf(y/S)\n"," erfImage[j][i] += 0.25*photon*ErfX*ErfY\n"," return erfImage\n","\n","\n","@njit(parallel=True)\n","def FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = (64,64), pixel_size = 100):\n"," w = image_size[0]\n"," h = image_size[1]\n"," locImage = 
np.zeros((image_size[0],image_size[1]) )\n"," n_locs = len(xc_array)\n","\n"," for e in prange(n_locs):\n"," locImage[int(max(min(round(yc_array[e]/pixel_size),w-1),0))][int(max(min(round(xc_array[e]/pixel_size),h-1),0))] += 1\n","\n"," return locImage\n","\n","\n","\n","def getPixelSizeTIFFmetadata(TIFFpath, display=False):\n"," with Image.open(TIFFpath) as img:\n"," meta_dict = {TAGS[key] : img.tag[key] for key in img.tag.keys()}\n","\n","\n"," # TIFF tags\n"," # https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml\n"," # https://www.awaresystems.be/imaging/tiff/tifftags/resolutionunit.html\n"," ResolutionUnit = meta_dict['ResolutionUnit'][0] # unit of resolution\n"," width = meta_dict['ImageWidth'][0]\n"," height = meta_dict['ImageLength'][0]\n","\n"," xResolution = meta_dict['XResolution'][0] # number of pixels / ResolutionUnit\n","\n"," if len(xResolution) == 1:\n"," xResolution = xResolution[0]\n"," elif len(xResolution) == 2:\n"," xResolution = xResolution[0]/xResolution[1]\n"," else:\n"," print('Image resolution not defined.')\n"," xResolution = 1\n","\n"," if ResolutionUnit == 2:\n"," # Units given are in inches\n"," pixel_size = 0.025*1e9/xResolution\n"," elif ResolutionUnit == 3:\n"," # Units given are in cm\n"," pixel_size = 0.01*1e9/xResolution\n"," else: \n"," # ResolutionUnit is therefore 1\n"," print('Resolution unit not defined. Assuming: um')\n"," pixel_size = 1e3/xResolution\n","\n"," if display:\n"," print('Pixel size obtained from metadata: '+str(pixel_size)+' nm')\n"," print('Image size: '+str(width)+'x'+str(height))\n"," \n"," return (pixel_size, width, height)\n","\n","\n","def saveAsTIF(path, filename, array, pixel_size):\n"," \"\"\"\n"," Image saving using PIL to save as .tif format\n"," # Input \n"," path - path where it will be saved\n"," filename - name of the file to save (no extension)\n"," array - numpy array conatining the data at the required format\n"," pixel_size - physical size of pixels in nanometers (identical for x and y)\n"," \"\"\"\n","\n"," # print('Data type: '+str(array.dtype))\n"," if (array.dtype == np.uint16):\n"," mode = 'I;16'\n"," elif (array.dtype == np.uint32):\n"," mode = 'I'\n"," else:\n"," mode = 'F'\n","\n"," # Rounding the pixel size to the nearest number that divides exactly 1cm.\n"," # Resolution needs to be a rational number --> see TIFF format\n"," # pixel_size = 10000/(round(10000/pixel_size))\n","\n"," if len(array.shape) == 2:\n"," im = Image.fromarray(array)\n"," im.save(os.path.join(path, filename+'.tif'),\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n","\n"," elif len(array.shape) == 3:\n"," imlist = []\n"," for frame in array:\n"," imlist.append(Image.fromarray(frame))\n","\n"," imlist[0].save(os.path.join(path, filename+'.tif'), save_all=True,\n"," append_images=imlist[1:],\n"," mode = mode, \n"," resolution_unit = 3,\n"," resolution = 0.01*1e9/pixel_size)\n","\n"," return\n","\n","\n","\n","\n","class Maximafinder(Layer):\n"," def __init__(self, thresh, neighborhood_size, use_local_avg, **kwargs):\n"," super(Maximafinder, self).__init__(**kwargs)\n"," self.thresh = tf.constant(thresh, dtype=tf.float32)\n"," self.nhood = neighborhood_size\n"," self.use_local_avg = use_local_avg\n","\n"," def build(self, input_shape):\n"," if self.use_local_avg is True:\n"," self.kernel_x = tf.reshape(tf.constant([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=tf.float32), [3, 3, 1, 1])\n"," self.kernel_y = tf.reshape(tf.constant([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=tf.float32), [3, 3, 1, 
1])\n"," self.kernel_sum = tf.reshape(tf.constant([[1,1,1],[1,1,1],[1,1,1]], dtype=tf.float32), [3, 3, 1, 1])\n","\n"," def call(self, inputs):\n","\n"," # local maxima positions\n"," max_pool_image = MaxPooling2D(pool_size=(self.nhood,self.nhood), strides=(1,1), padding='same')(inputs)\n"," cond = tf.math.greater(max_pool_image, self.thresh) & tf.math.equal(max_pool_image, inputs)\n"," indices = tf.where(cond)\n"," bind, xind, yind = indices[:, 0], indices[:, 2], indices[:, 1]\n"," confidence = tf.gather_nd(inputs, indices)\n","\n"," # local CoG estimator\n"," if self.use_local_avg:\n"," x_image = K.conv2d(inputs, self.kernel_x, padding='same')\n"," y_image = K.conv2d(inputs, self.kernel_y, padding='same')\n"," sum_image = K.conv2d(inputs, self.kernel_sum, padding='same')\n"," confidence = tf.cast(tf.gather_nd(sum_image, indices), dtype=tf.float32)\n"," x_local = tf.math.divide(tf.gather_nd(x_image, indices),tf.gather_nd(sum_image, indices))\n"," y_local = tf.math.divide(tf.gather_nd(y_image, indices),tf.gather_nd(sum_image, indices))\n"," xind = tf.cast(xind, dtype=tf.float32) + tf.cast(x_local, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32) + tf.cast(y_local, dtype=tf.float32)\n"," else:\n"," xind = tf.cast(xind, dtype=tf.float32)\n"," yind = tf.cast(yind, dtype=tf.float32)\n"," \n"," return bind, xind, yind, confidence\n","\n"," def get_config(self):\n","\n"," # Implement get_config to enable serialization. This is optional.\n"," base_config = super(Maximafinder, self).get_config()\n"," config = {}\n"," return dict(list(base_config.items()) + list(config.items()))\n","\n","\n","\n","# ------------------------------- Prediction with postprocessing function-------------------------------\n","def batchFramePredictionLocalization(dataPath, filename, modelPath, savePath, batch_size=1, thresh=0.1, neighborhood_size=3, use_local_avg = False, pixel_size = None):\n"," \"\"\"\n"," This function tests a trained model on the desired test set, given the \n"," tiff stack of test images, learned weights, and normalization factors.\n"," \n"," # Inputs\n"," dataPath - the path to the folder containing the tiff stack(s) to run prediction on \n"," filename - the name of the file to process\n"," modelPath - the path to the folder containing the weights file and the mean and standard deviation file generated in train_model\n"," savePath - the path to the folder where to save the prediction\n"," batch_size. 
- the number of frames to predict on for each iteration\n"," thresh - threshoold percentage from the maximum of the gaussian scaling\n"," neighborhood_size - the size of the neighborhood for local maxima finding\n"," use_local_average - Boolean whether to perform local averaging or not\n"," \"\"\"\n"," \n"," # load mean and std\n"," matfile = sio.loadmat(os.path.join(modelPath,'model_metadata.mat'))\n"," test_mean = np.array(matfile['mean_test'])\n"," test_std = np.array(matfile['std_test']) \n"," upsampling_factor = np.array(matfile['upsampling_factor'])\n"," upsampling_factor = upsampling_factor.item() # convert to scalar\n"," L2_weighting_factor = np.array(matfile['Normalization factor'])\n"," L2_weighting_factor = L2_weighting_factor.item() # convert to scalar\n","\n"," # Read in the raw file\n"," Images = io.imread(os.path.join(dataPath, filename))\n"," if pixel_size == None:\n"," pixel_size, _, _ = getPixelSizeTIFFmetadata(os.path.join(dataPath, filename), display=True)\n"," pixel_size_hr = pixel_size/upsampling_factor\n","\n"," # get dataset dimensions\n"," (nFrames, M, N) = Images.shape\n"," print('Input image is '+str(N)+'x'+str(M)+' with '+str(nFrames)+' frames.')\n","\n"," # Build the model for a bigger image\n"," model = buildModel((upsampling_factor*M, upsampling_factor*N, 1))\n","\n"," # Load the trained weights\n"," model.load_weights(os.path.join(modelPath,'weights_best.hdf5'))\n","\n"," # add a post-processing module\n"," max_layer = Maximafinder(thresh*L2_weighting_factor, neighborhood_size, use_local_avg)\n","\n"," # Initialise the results: lists will be used to collect all the localizations\n"," frame_number_list, x_nm_list, y_nm_list, confidence_au_list = [], [], [], []\n","\n"," # Initialise the results\n"," Prediction = np.zeros((M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n"," Widefield = np.zeros((M, N), dtype=np.float32)\n","\n"," # run model in batches\n"," n_batches = math.ceil(nFrames/batch_size)\n"," for b in tqdm(range(n_batches)):\n","\n"," nF = min(batch_size, nFrames - b*batch_size)\n"," Images_norm = np.zeros((nF, M, N),dtype=np.float32)\n"," Images_upsampled = np.zeros((nF, M*upsampling_factor, N*upsampling_factor), dtype=np.float32)\n","\n"," # Upsampling using a simple nearest neighbor interp and calculating - MULTI-THREAD this?\n"," for f in range(nF):\n"," Images_norm[f,:,:] = project_01(Images[b*batch_size+f,:,:])\n"," Images_norm[f,:,:] = normalize_im(Images_norm[f,:,:], test_mean, test_std)\n"," Images_upsampled[f,:,:] = np.kron(Images_norm[f,:,:], np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield += Images[b*batch_size+f,:,:]\n","\n"," # Reshaping\n"," Images_upsampled = np.expand_dims(Images_upsampled,axis=3)\n","\n"," # Run prediction and local amxima finding\n"," predicted_density = model.predict_on_batch(Images_upsampled)\n"," predicted_density[predicted_density < 0] = 0\n"," Prediction += predicted_density.sum(axis = 3).sum(axis = 0)\n","\n"," bind, xind, yind, confidence = max_layer(predicted_density)\n"," \n"," # normalizing the confidence by the L2_weighting_factor\n"," confidence /= L2_weighting_factor \n","\n"," # turn indices to nms and append to the results\n"," xind, yind = xind*pixel_size_hr, yind*pixel_size_hr\n"," frmind = (bind.numpy() + b*batch_size + 1).tolist()\n"," xind = xind.numpy().tolist()\n"," yind = yind.numpy().tolist()\n"," confidence = confidence.numpy().tolist()\n"," frame_number_list += frmind\n"," x_nm_list += xind\n"," y_nm_list += yind\n"," confidence_au_list += confidence\n","\n"," 
# Open and create the csv file that will contain all the localizations\n"," if use_local_avg:\n"," ext = '_avg'\n"," else:\n"," ext = '_max'\n"," with open(os.path.join(savePath, 'Localizations_' + os.path.splitext(filename)[0] + ext + '.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow(['frame', 'x [nm]', 'y [nm]', 'confidence [a.u]'])\n"," locs = list(zip(frame_number_list, x_nm_list, y_nm_list, confidence_au_list))\n"," writer.writerows(locs)\n","\n"," # Save the prediction and widefield image\n"," Widefield = np.kron(Widefield, np.ones((upsampling_factor,upsampling_factor)))\n"," Widefield = np.float32(Widefield)\n","\n"," # io.imsave(os.path.join(savePath, 'Predicted_'+os.path.splitext(filename)[0]+'.tif'), Prediction)\n"," # io.imsave(os.path.join(savePath, 'Widefield_'+os.path.splitext(filename)[0]+'.tif'), Widefield)\n","\n"," saveAsTIF(savePath, 'Predicted_'+os.path.splitext(filename)[0], Prediction, pixel_size_hr)\n"," saveAsTIF(savePath, 'Widefield_'+os.path.splitext(filename)[0], Widefield, pixel_size_hr)\n","\n","\n"," return\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n","\n","\n","\n","def list_files(directory, extension):\n"," return (f for f in os.listdir(directory) if f.endswith('.' + extension))\n","\n","\n","# @njit(parallel=True)\n","def subPixelMaxLocalization(array, method = 'CoM', patch_size = 3):\n"," xMaxInd, yMaxInd = np.unravel_index(array.argmax(), array.shape, order='C')\n"," centralPatch = XC[(xMaxInd-patch_size):(xMaxInd+patch_size+1),(yMaxInd-patch_size):(yMaxInd+patch_size+1)]\n","\n"," if (method == 'MAX'):\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n","\n"," elif (method == 'CoM'):\n"," x0 = 0\n"," y0 = 0\n"," S = 0\n"," for xy in range(patch_size*patch_size):\n"," y = math.floor(xy/patch_size)\n"," x = xy - y*patch_size\n"," x0 += x*array[x,y]\n"," y0 += y*array[x,y]\n"," S = array[x,y]\n"," \n"," x0 = x0/S - patch_size/2 + xMaxInd\n"," y0 = y0/S - patch_size/2 + yMaxInd\n"," \n"," elif (method == 'Radiality'):\n"," # Not implemented yet\n"," x0 = xMaxInd\n"," y0 = yMaxInd\n"," \n"," return (x0, y0)\n","\n","\n","@njit(parallel=True)\n","def correctDriftLocalization(xc_array, yc_array, frames, xDrift, yDrift):\n"," n_locs = xc_array.shape[0]\n"," xc_array_Corr = np.empty(n_locs)\n"," yc_array_Corr = np.empty(n_locs)\n"," \n"," for loc in prange(n_locs):\n"," xc_array_Corr[loc] = xc_array[loc] - xDrift[frames[loc]]\n"," yc_array_Corr[loc] = yc_array[loc] - yDrift[frames[loc]]\n","\n"," return (xc_array_Corr, yc_array_Corr)\n","\n","\n","print('--------------------------------')\n","print('DeepSTORM installation complete.')\n","\n","# Check if this is the latest version of the notebook\n","\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","\n","\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","# Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","# if Notebook_version == list(Latest_notebook_version.columns):\n","# print(\"This notebook is up-to-date.\")\n","\n","# if not Notebook_version == list(Latest_notebook_version.columns):\n","# print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, raw_data = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," \n"," #model_name = 'little_CARE_test'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hours)+ \"hour(s) \"+str(minutes)+\"min(s) \"+str(round(seconds))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n"," if raw_data == True:\n"," shape = (M,N)\n"," else:\n"," shape = (int(FOV_size/pixel_size),int(FOV_size/pixel_size))\n"," #dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+' GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(n_patches)+' paired image patches (image dimensions: '+str(patch_size)+', patch size (upsampled): ('+str(int(patch_size))+','+str(int(patch_size))+') with a batch size of '+str(batch_size)+', using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Losses were calculated using MSE for the heatmaps and L1 loss for the spike prediction. The models was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), Keras (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+' GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(180, 5, txt = text, align='L')\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if raw_data==False:\n"," simul_text = 'The training dataset was created in the notebook using the following simulation settings:'\n"," pdf.cell(200, 5, txt=simul_text, align='L')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Setting</th><th>Simulated Value</th></tr>
<tr><td>FOV_size</td><td>{0}</td></tr>
<tr><td>pixel_size</td><td>{1}</td></tr>
<tr><td>ADC_per_photon_conversion</td><td>{2}</td></tr>
<tr><td>ReadOutNoise_ADC</td><td>{3}</td></tr>
<tr><td>ADC_offset</td><td>{4}</td></tr>
<tr><td>emitter_density</td><td>{5}</td></tr>
<tr><td>emitter_density_std</td><td>{6}</td></tr>
<tr><td>number_of_frames</td><td>{7}</td></tr>
<tr><td>sigma</td><td>{8}</td></tr>
<tr><td>sigma_std</td><td>{9}</td></tr>
<tr><td>n_photons</td><td>{10}</td></tr>
<tr><td>n_photons_std</td><td>{11}</td></tr>
\n"," \"\"\".format(FOV_size, pixel_size, ADC_per_photon_conversion, ReadOutNoise_ADC, ADC_offset, emitter_density, emitter_density_std, number_of_frames, sigma, sigma_std, n_photons, n_photons_std)\n"," pdf.write_html(html)\n"," else:\n"," simul_text = 'The training dataset was simulated using ThunderSTORM and loaded into the notebook.'\n"," pdf.multi_cell(190, 5, txt=simul_text, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," #pdf.ln(1)\n"," #pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'ImageData_path', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = ImageData_path, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'LocalizationData_path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = LocalizationData_path, align = 'L')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'pixel_size:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = str(pixel_size), align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used to generate patches:')\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(str(patch_size)+'x'+str(patch_size), upsampling_factor, num_patches_per_frame, min_number_of_emitters_per_patch, max_num_patches, gaussian_sigma, Automatic_normalization, L2_weighting_factor)\n"," pdf.write_html(html)\n"," pdf.ln(3)\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n","
<tr><th>Patch Parameter</th><th>Value</th></tr>
<tr><td>patch_size</td><td>{0}</td></tr>
<tr><td>upsampling_factor</td><td>{1}</td></tr>
<tr><td>num_patches_per_frame</td><td>{2}</td></tr>
<tr><td>min_number_of_emitters_per_patch</td><td>{3}</td></tr>
<tr><td>max_num_patches</td><td>{4}</td></tr>
<tr><td>gaussian_sigma</td><td>{5}</td></tr>
<tr><td>Automatic_normalization</td><td>{6}</td></tr>
<tr><td>L2_weighting_factor</td><td>{7}</td></tr>
\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Training Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>batch_size</td><td>{1}</td></tr>
<tr><td>number_of_steps</td><td>{2}</td></tr>
<tr><td>percentage_validation</td><td>{3}</td></tr>
<tr><td>initial_learning_rate</td><td>{4}</td></tr>
\n"," \"\"\".format(number_of_epochs,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," # pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n","\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training Images', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_DeepSTORM2D.png').shape\n"," pdf.image('/content/TrainingDataExample_DeepSTORM2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Deep-STORM'\n"," #model_name = os.path.basename(full_QC_model_path)\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+os.path.basename(QC_model_path)+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," if os.path.exists(savePath+'/lossCurvePlots.png'):\n"," exp_size = io.imread(savePath+'/lossCurvePlots.png').shape\n"," pdf.image(savePath+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(savePath+'/QC_example_data.png').shape\n"," pdf.image(savePath+'/QC_example_data.png', x = 16, y = None, w = 
round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(savePath+'/'+os.path.basename(QC_model_path)+'_QC_metrics.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Deep-STORM: Nehme, Elias, et al. \"Deep-STORM: super-resolution single-molecule microscopy by deep learning.\" Optica 5.4 (2018): 458-464.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n","\n","\n"," print('------------------------------')\n"," print('QC PDF report exported as '+savePath+'/'+os.path.basename(QC_model_path)+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"E04mOlG_H5Tz"},"source":["# **2. Complete the Colab session**\n","---"]},{"cell_type":"markdown","metadata":{"id":"F_tjlGzsH-Dn"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"gn-LaaNNICqL","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","import tensorflow as tf\n","# if tf.__version__ != '2.2.0':\n","# !pip install tensorflow==2.2.0\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime settings are correct then Google did not allocate GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tnP7wM79IKW-"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"1R-7Fo34_gOd","cellView":"form"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vu8f5NGJkJos"},"source":["\n","# **3. Generate patches for training**\n","---\n","\n","For Deep-STORM the training data can be obtained in two ways:\n","* Simulated using ThunderSTORM or other simulation tool and loaded here (**using Section 3.1.a**)\n","* Directly simulated in this notebook (**using Section 3.1.b**)\n"]},{"cell_type":"markdown","metadata":{"id":"WSV8xnlynp0l"},"source":["## **3.1.a Load training data**\n","---\n","\n","Here you can load your simulated data along with its corresponding localization file.\n","* The `pixel_size` is defined in nanometer (nm). "]},{"cell_type":"code","metadata":{"id":"CT6SNcfNg6j0","cellView":"form"},"source":["#@markdown ##Load raw data\n","\n","load_raw_data = True\n","\n","# Get user input\n","ImageData_path = \"\" #@param {type:\"string\"}\n","LocalizationData_path = \"\" #@param {type: \"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size,_,_ = getPixelSizeTIFFmetadata(ImageData_path, True)\n","\n","# load the tiff data\n","Images = io.imread(ImageData_path)\n","# get dataset dimensions\n","if len(Images.shape) == 3:\n"," (number_of_frames, M, N) = Images.shape\n","elif len(Images.shape) == 2:\n"," (M, N) = Images.shape\n"," number_of_frames = 1\n","print('Loaded images: '+str(M)+'x'+str(N)+' with '+str(number_of_frames)+' frames')\n","\n","# Interactive display of the stack\n","def scroll_in_time(frame):\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source at frame = ' + str(frame))\n"," plt.axis('off');\n","\n","if number_of_frames > 1:\n"," interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","else:\n"," f=plt.figure(figsize=(6,6))\n"," plt.imshow(Images, interpolation='nearest', cmap = 'gray')\n"," plt.title('Training source')\n"," plt.axis('off');\n","\n","# Load the localization file and display the first\n","LocData = pd.read_csv(LocalizationData_path, index_col=0)\n","LocData.tail()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9xE5GeYiks9"},"source":["## **3.1.b Simulate training data**\n","---\n","This simulation tool allows you to generate SMLM data of randomly distrubuted emitters in a field-of-view. \n","The assumptions are as follows:\n","\n","* Gaussian Point Spread Function (PSF) with standard deviation defined by `Sigma`. The nominal value of `sigma` can be evaluated using `sigma = 0.21 x Lambda / NA`. 
(from [Zhang *et al.*, Applied Optics 2007](https://doi.org/10.1364/AO.46.001819))\n","* Each emitter will emit `n_photons` per frame, and generate their equivalent Poisson noise.\n","* The camera will contribute Gaussian noise to the signal with a standard deviation defined by `ReadOutNoise_ADC` in ADC\n","* The `emitter_density` is defined as the number of emitters / um^2 on any given frame. Variability in the emitter density can be applied by adjusting `emitter_density_std`. The latter parameter represents the standard deviation of the normal distribution that the density is drawn from for each individual frame. `emitter_density` **is defined in number of emitters / um^2**.\n","* The `n_photons` and `sigma` can additionally include some Gaussian variability by setting `n_photons_std` and `sigma_std`.\n","\n","Important note:\n","- All dimensions are in nanometer (e.g. `FOV_size` = 6400 represents a field of view of 6.4 um x 6.4 um).\n","\n"]},{"cell_type":"code","metadata":{"id":"sQyLXpEhitsg","cellView":"form"},"source":["load_raw_data = False\n","\n","# ---------------------------- User input ----------------------------\n","#@markdown Run the simulation\n","#@markdown --- \n","#@markdown Camera settings: \n","FOV_size = 6400#@param {type:\"number\"}\n","pixel_size = 100#@param {type:\"number\"}\n","ADC_per_photon_conversion = 1 #@param {type:\"number\"}\n","ReadOutNoise_ADC = 4.5#@param {type:\"number\"}\n","ADC_offset = 50#@param {type:\"number\"}\n","\n","#@markdown Acquisition settings: \n","emitter_density = 6#@param {type:\"number\"}\n","emitter_density_std = 0#@param {type:\"number\"}\n","\n","number_of_frames = 20#@param {type:\"integer\"}\n","\n","sigma = 110 #@param {type:\"number\"}\n","sigma_std = 5 #@param {type:\"number\"}\n","# NA = 1.1 #@param {type:\"number\"}\n","# wavelength = 800#@param {type:\"number\"}\n","# wavelength_std = 150#@param {type:\"number\"}\n","n_photons = 2250#@param {type:\"number\"}\n","n_photons_std = 250#@param {type:\"number\"}\n","\n","\n","# ---------------------------- Variable initialisation ----------------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","print('-----------------------------------------------------------')\n","n_molecules = emitter_density*FOV_size*FOV_size/10**6\n","n_molecules_std = emitter_density_std*FOV_size*FOV_size/10**6\n","print('Number of molecules / FOV: '+str(round(n_molecules,2))+' +/- '+str((round(n_molecules_std,2))))\n","\n","# sigma = 0.21*wavelength/NA\n","# sigma_std = 0.21*wavelength_std/NA\n","# print('Gaussian PSF sigma: '+str(round(sigma,2))+' +/- '+str(round(sigma_std,2))+' nm')\n","\n","M = N = round(FOV_size/pixel_size)\n","FOV_size = M*pixel_size\n","print('Final image size: '+str(M)+'x'+str(M)+' ('+str(round(FOV_size/1000, 3))+'um x'+str(round(FOV_size/1000,3))+' um)')\n","\n","np.random.seed(1)\n","display_upsampling = 8 # used to display the loc map here\n","NoiseFreeImages = np.zeros((number_of_frames, M, M))\n","locImage = np.zeros((number_of_frames, display_upsampling*M, display_upsampling*N))\n","\n","frames = []\n","all_xloc = []\n","all_yloc = []\n","all_photons = []\n","all_sigmas = []\n","\n","# ---------------------------- Main simulation loop ----------------------------\n","print('-----------------------------------------------------------')\n","for f in tqdm(range(number_of_frames)):\n"," \n"," # Define the coordinates of emitters by randomly distributing them across the FOV\n"," n_mol = int(max(round(np.random.normal(n_molecules, 
n_molecules_std, size=1)[0]), 0))\n"," x_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," y_c = np.random.uniform(low=0.0, high=FOV_size, size=n_mol)\n"," photon_array = np.random.normal(n_photons, n_photons_std, size=n_mol)\n"," sigma_array = np.random.normal(sigma, sigma_std, size=n_mol)\n"," # x_c = np.linspace(0,3000,5)\n"," # y_c = np.linspace(0,3000,5)\n","\n"," all_xloc += x_c.tolist()\n"," all_yloc += y_c.tolist()\n"," frames += ((f+1)*np.ones(x_c.shape[0])).tolist()\n"," all_photons += photon_array.tolist()\n"," all_sigmas += sigma_array.tolist()\n","\n"," locImage[f] = FromLoc2Image_SimpleHistogram(x_c, y_c, image_size = (N*display_upsampling, M*display_upsampling), pixel_size = pixel_size/display_upsampling)\n","\n"," # # Get the approximated locations according to the grid pixel size\n"," # Chr_emitters = [int(max(min(round(display_upsampling*x_c[i]/pixel_size),N*display_upsampling-1),0)) for i in range(len(x_c))]\n"," # Rhr_emitters = [int(max(min(round(display_upsampling*y_c[i]/pixel_size),M*display_upsampling-1),0)) for i in range(len(y_c))]\n","\n"," # # Build Localization image\n"," # for (r,c) in zip(Rhr_emitters, Chr_emitters):\n"," # locImage[f][r][c] += 1\n","\n"," NoiseFreeImages[f] = FromLoc2Image_Erf(x_c, y_c, photon_array, sigma_array, image_size = (M,M), pixel_size = pixel_size)\n","\n","\n","# ---------------------------- Create DataFrame fof localization file ----------------------------\n","# Table with localization info as dataframe output\n","LocData = pd.DataFrame()\n","LocData[\"frame\"] = frames\n","LocData[\"x [nm]\"] = all_xloc\n","LocData[\"y [nm]\"] = all_yloc\n","LocData[\"Photon #\"] = all_photons\n","LocData[\"Sigma [nm]\"] = all_sigmas\n","LocData.index += 1 # set indices to start at 1 and not 0 (same as ThunderSTORM)\n","\n","\n","# ---------------------------- Estimation of SNR ----------------------------\n","n_frames_for_SNR = 100\n","M_SNR = 10\n","x_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","y_c = np.random.uniform(low=0.0, high=pixel_size*M_SNR, size=n_frames_for_SNR)\n","photon_array = np.random.normal(n_photons, n_photons_std, size=n_frames_for_SNR)\n","sigma_array = np.random.normal(sigma, sigma_std, size=n_frames_for_SNR)\n","\n","SNR = np.zeros(n_frames_for_SNR)\n","for i in range(n_frames_for_SNR):\n"," SingleEmitterImage = FromLoc2Image_Erf(np.array([x_c[i]]), np.array([x_c[i]]), np.array([photon_array[i]]), np.array([sigma_array[i]]), (M_SNR, M_SNR), pixel_size)\n"," Signal_photon = np.max(SingleEmitterImage)\n"," Noise_photon = math.sqrt((ReadOutNoise_ADC/ADC_per_photon_conversion)**2 + Signal_photon)\n"," SNR[i] = Signal_photon/Noise_photon\n","\n","print('SNR: '+str(round(np.mean(SNR),2))+' +/- '+str(round(np.std(SNR),2)))\n","# ---------------------------- ----------------------------\n","\n","\n","# Table with info\n","simParameters = pd.DataFrame()\n","simParameters[\"FOV size (nm)\"] = [FOV_size]\n","simParameters[\"Pixel size (nm)\"] = [pixel_size]\n","simParameters[\"ADC/photon\"] = [ADC_per_photon_conversion]\n","simParameters[\"Read-out noise (ADC)\"] = [ReadOutNoise_ADC]\n","simParameters[\"Constant offset (ADC)\"] = [ADC_offset]\n","\n","simParameters[\"Emitter density (emitters/um^2)\"] = [emitter_density]\n","simParameters[\"STD of emitter density (emitters/um^2)\"] = [emitter_density_std]\n","simParameters[\"Number of frames\"] = [number_of_frames]\n","# simParameters[\"NA\"] = [NA]\n","# simParameters[\"Wavelength (nm)\"] = [wavelength]\n","# simParameters[\"STD of 
wavelength (nm)\"] = [wavelength_std]\n","simParameters[\"Sigma (nm))\"] = [sigma]\n","simParameters[\"STD of Sigma (nm))\"] = [sigma_std]\n","simParameters[\"Number of photons\"] = [n_photons]\n","simParameters[\"STD of number of photons\"] = [n_photons_std]\n","simParameters[\"SNR\"] = [np.mean(SNR)]\n","simParameters[\"STD of SNR\"] = [np.std(SNR)]\n","\n","\n","# ---------------------------- Finish simulation ----------------------------\n","# Calculating the noisy image\n","Images = ADC_per_photon_conversion * np.random.poisson(NoiseFreeImages) + ReadOutNoise_ADC * np.random.normal(size = (number_of_frames, M, N)) + ADC_offset\n","Images[Images <= 0] = 0\n","\n","# Convert to 16-bit or 32-bits integers\n","if Images.max() < (2**16-1):\n"," Images = Images.astype(np.uint16)\n","else:\n"," Images = Images.astype(np.uint32)\n","\n","\n","# ---------------------------- Display ----------------------------\n","# Displaying the time elapsed for simulation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds,1),\"sec(s)\")\n","\n","\n","# Interactively display the results using Widgets\n","def scroll_in_time(frame):\n"," f = plt.figure(figsize=(18,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(locImage[frame-1], interpolation='bilinear', vmin = 0, vmax=0.1)\n"," plt.title('Localization image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(NoiseFreeImages[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noise-free simulation')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(Images[frame-1], interpolation='nearest', cmap='gray')\n"," plt.title('Noisy simulation')\n"," plt.axis('off');\n","\n","interact(scroll_in_time, frame=widgets.IntSlider(min=1, max=Images.shape[0], step=1, value=0, continuous_update=False));\n","\n","# Display the head of the dataframe with localizations\n","LocData.tail()\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pz7RfSuoeJeq","cellView":"form"},"source":["#@markdown ---\n","#@markdown ##Play this cell to save the simulated stack\n","#@markdown Please select a path to the folder where to save the simulated data. It is not necessary to save the data to run the training, but keeping the simulated for your own record can be useful to check its validity.\n","Save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(Save_path):\n"," os.makedirs(Save_path)\n"," print('Folder created.')\n","else:\n"," print('Training data already exists in folder: Data overwritten.')\n","\n","saveAsTIF(Save_path, 'SimulatedDataset', Images, pixel_size)\n","# io.imsave(os.path.join(Save_path, 'SimulatedDataset.tif'),Images)\n","LocData.to_csv(os.path.join(Save_path, 'SimulatedDataset.csv'))\n","simParameters.to_csv(os.path.join(Save_path, 'SimulatedParameters.csv'))\n","print('Training dataset saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K_8e3kE-JhVY"},"source":["## **3.2. Generate training patches**\n","---\n","\n","Training patches need to be created from the training data generated above. \n","* The `patch_size` needs to give sufficient contextual information and for most cases a `patch_size` of 26 (corresponding to patches of 26x26 pixels) works fine. 
**DEFAULT: 26**\n","* The `upsampling_factor` defines the effective magnification of the final super-resolved image compared to the input image (this is called magnification in ThunderSTORM). This is used to generate the super-resolved patches as the target dataset. Using an `upsampling_factor` of 16 will require the use of more memory and it may be necessary to decrease the `patch_size` to 16 for example. **DEFAULT: 8**\n","* The `num_patches_per_frame` defines the number of patches extracted from each frame generated in section 3.1. **DEFAULT: 500**\n","* The `min_number_of_emitters_per_patch` defines the minimum number of emitters that need to be present in the patch to be a valid patch. An empty patch does not contain useful information for the network to learn from. **DEFAULT: 7**\n","* The `max_num_patches` defines the maximum number of patches to generate. Fewer may be generated depending on how many patches are rejected and how many frames are available. **DEFAULT: 10000**\n","* The `gaussian_sigma` defines the Gaussian standard deviation (in magnified pixels) applied to generate the super-resolved target image. **DEFAULT: 1**\n","* The `L2_weighting_factor` is a normalization factor used in the loss function. It helps balance the loss from the L2 norm. When using higher densities, this factor should be decreased and vice versa. This factor can be automatically calculated using an empirical formula. **DEFAULT: 100**\n","\n"]},{"cell_type":"code","metadata":{"id":"AsNx5KzcFNvC","cellView":"form"},"source":["#@markdown ## **Provide patch parameters**\n","\n","\n","# -------------------- User input --------------------\n","patch_size = 26 #@param {type:\"integer\"}\n","upsampling_factor = 8 #@param [\"4\", \"8\", \"16\"] {type:\"raw\"}\n","num_patches_per_frame = 500#@param {type:\"integer\"}\n","min_number_of_emitters_per_patch = 7#@param {type:\"integer\"}\n","max_num_patches = 10000#@param {type:\"integer\"}\n","gaussian_sigma = 1#@param {type:\"integer\"}\n","\n","#@markdown Estimate the optimal normalization factor automatically?\n","Automatic_normalization = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, it will use the following value:\n","L2_weighting_factor = 100 #@param {type:\"number\"}\n","\n","\n","# -------------------- Prepare variables --------------------\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# Initialize some parameters\n","pixel_size_hr = pixel_size/upsampling_factor # in nm\n","n_patches = min(number_of_frames*num_patches_per_frame, max_num_patches)\n","patch_size = patch_size*upsampling_factor\n","\n","# Dimensions of the high-res grid\n","Mhr = upsampling_factor*M # in pixels\n","Nhr = upsampling_factor*N # in pixels\n","\n","# Initialize the training patches and labels\n","patches = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","spikes = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","heatmaps = np.zeros((n_patches, patch_size, patch_size), dtype = np.float32)\n","\n","# Run over all frames and construct the training examples\n","k = 1 # current patch count\n","skip_counter = 0 # number of patches skipped due to low density\n","id_start = 0 # id position in LocData for current frame\n","print('Generating '+str(n_patches)+' patches of '+str(patch_size)+'x'+str(patch_size))\n","\n","n_locs = len(LocData.index)\n","print('Total number of localizations: '+str(n_locs))\n","density = n_locs/(M*N*number_of_frames*(0.001*pixel_size)**2)\n","print('Density: 
'+str(round(density,2))+' locs/um^2')\n","n_locs_per_patch = patch_size**2*density\n","\n","if Automatic_normalization:\n"," # This empirical formulae attempts to balance the loss L2 function between the background and the bright spikes\n"," # A value of 100 was originally chosen to balance L2 for a patch size of 2.6x2.6^2 0.1um pixel size and density of 3 (hence the 20.28), at upsampling_factor = 8\n"," L2_weighting_factor = 100/math.sqrt(min(n_locs_per_patch, min_number_of_emitters_per_patch)*8**2/(upsampling_factor**2*20.28))\n"," print('Normalization factor: '+str(round(L2_weighting_factor,2)))\n","\n","# -------------------- Patch generation loop --------------------\n","\n","print('-----------------------------------------------------------')\n","for (f, thisFrame) in enumerate(tqdm(Images)):\n","\n"," # Upsample the frame\n"," upsampledFrame = np.kron(thisFrame, np.ones((upsampling_factor,upsampling_factor)))\n"," # Read all the provided high-resolution locations for current frame\n"," DataFrame = LocData[LocData['frame'] == f+1].copy()\n","\n"," # Get the approximated locations according to the high-res grid pixel size\n"," Chr_emitters = [int(max(min(round(DataFrame['x [nm]'][i]/pixel_size_hr),Nhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," Rhr_emitters = [int(max(min(round(DataFrame['y [nm]'][i]/pixel_size_hr),Mhr-1),0)) for i in range(id_start+1,id_start+1+len(DataFrame.index))]\n"," id_start += len(DataFrame.index)\n","\n"," # Build Localization image\n"," LocImage = np.zeros((Mhr,Nhr))\n"," LocImage[(Rhr_emitters, Chr_emitters)] = 1\n","\n"," # Here, there's a choice between the original Gaussian (classification approach) and using the erf function\n"," HeatMapImage = L2_weighting_factor*gaussian_filter(LocImage, gaussian_sigma) \n"," # HeatMapImage = L2_weighting_factor*FromLoc2Image_MultiThreaded(np.array(list(DataFrame['x [nm]'])), np.array(list(DataFrame['y [nm]'])), \n"," # np.ones(len(DataFrame.index)), pixel_size_hr*gaussian_sigma*np.ones(len(DataFrame.index)), \n"," # Mhr, pixel_size_hr)\n"," \n","\n"," # Generate random position for the top left corner of the patch\n"," xc = np.random.randint(0, Mhr-patch_size, size=num_patches_per_frame)\n"," yc = np.random.randint(0, Nhr-patch_size, size=num_patches_per_frame)\n","\n"," for c in range(len(xc)):\n"," if LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size].sum() < min_number_of_emitters_per_patch:\n"," skip_counter += 1\n"," continue\n"," \n"," else:\n"," # Limit maximal number of training examples to 15k\n"," if k > max_num_patches:\n"," break\n"," else:\n"," # Assign the patches to the right part of the images\n"," patches[k-1] = upsampledFrame[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," spikes[k-1] = LocImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," heatmaps[k-1] = HeatMapImage[xc[c]:xc[c]+patch_size, yc[c]:yc[c]+patch_size]\n"," k += 1 # increment current patch count\n","\n","# Remove the empty data\n","patches = patches[:k-1]\n","spikes = spikes[:k-1]\n","heatmaps = heatmaps[:k-1]\n","n_patches = k-1\n","\n","# -------------------- Failsafe --------------------\n","# Check if the size of the training set is smaller than 5k to notify user to simulate more images using ThunderSTORM\n","if ((k-1) < 5000):\n"," # W = '\\033[0m' # white (normal)\n"," # R = '\\033[31m' # red\n"," print(bcolors.WARNING+'!! WARNING: Training set size is below 5K - Consider simulating more images in ThunderSTORM. 
!!'+bcolors.NORMAL)\n","\n","\n","\n","# -------------------- Displays --------------------\n","print('Number of patches skipped due to low density: '+str(skip_counter))\n","# dataSize = int((getsizeof(patches)+getsizeof(heatmaps)+getsizeof(spikes))/(1024*1024)) #rounded in MB\n","# print('Size of patches: '+str(dataSize)+' MB')\n","print(str(n_patches)+' patches were generated.')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display patches interactively with a slider\n","def scroll_patches(patch):\n"," f = plt.figure(figsize=(16,6))\n"," plt.subplot(1,3,1)\n"," plt.imshow(patches[patch-1], interpolation='nearest', cmap='gray')\n"," plt.title('Raw data (frame #'+str(patch)+')')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,2)\n"," plt.imshow(heatmaps[patch-1], interpolation='nearest')\n"," plt.title('Heat map')\n"," plt.axis('off');\n","\n"," plt.subplot(1,3,3)\n"," plt.imshow(spikes[patch-1], interpolation='nearest')\n"," plt.title('Localization map')\n"," plt.axis('off');\n"," \n"," plt.savefig('/content/TrainingDataExample_DeepSTORM2D.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","interact(scroll_patches, patch=widgets.IntSlider(min=1, max=patches.shape[0], step=1, value=0, continuous_update=False));\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"DSjXFMevK7Iz"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"hVeyKU0MdAPx"},"source":["## **4.1. Select your paths and parameters**\n","\n","---\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for ~100 epochs. Evaluate the performance after training (see 5). **Default value: 80**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. **If this value is set to 0**, by default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 30** \n","\n","**`initial_learning_rate`:** This parameter represents the initial value to be used as learning rate in the optimizer. 
**Default value: 0.001**"]},{"cell_type":"code","metadata":{"id":"oa5cDZ7f_PF6","cellView":"form"},"source":["#@markdown ###Path to training images and parameters\n","\n","model_path = \"\" #@param {type: \"string\"} \n","model_name = \"\" #@param {type: \"string\"} \n","number_of_epochs = 80#@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"integer\"}\n","\n","number_of_steps = 0#@param {type:\"integer\"}\n","percentage_validation = 30 #@param {type:\"number\"}\n","initial_learning_rate = 0.001 #@param {type:\"number\"}\n","\n","\n","percentage_validation /= 100\n","if number_of_steps == 0: \n"," number_of_steps = int((1-percentage_validation)*n_patches/batch_size)\n"," print('Number of steps: '+str(number_of_steps))\n","\n","# Pretrained model path initialised here so next cell does not need to be run\n","h5_file_path = ''\n","Use_pretrained_model = False\n","\n","if not ('patches' in locals()):\n"," # Warn the user if no training patches are currently in memory (section 3.2 has not been run)\n"," print(bcolors.WARNING+'!! WARNING: No patches were found in memory currently. !!'+bcolors.NORMAL)\n","\n","Save_path = os.path.join(model_path, model_name)\n","if os.path.exists(Save_path):\n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","print('-----------------------------')\n","print('Training parameters set.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WIyEvQBWLp9n"},"source":["\n","## **4.2. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a Deep-STORM 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is advisable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
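As a rough sketch of what this amounts to (assuming a previous ZeroCostDL4Mic Deep-STORM model folder containing `weights_best.hdf5` / `weights_last.hdf5` and `Quality Control/training_evaluation.csv` with `learning rate` and `val_loss` columns; the helper name `resume_from` is purely illustrative, the actual logic lives in the cell below and inside `train_model`):

```python
import os
import pandas as pd

def resume_from(pretrained_model_path, weights_choice="best", default_lr=0.001):
    # Weights saved by the previous training session
    weights_file = os.path.join(pretrained_model_path, "weights_" + weights_choice + ".hdf5")
    learning_rate = default_lr
    history_csv = os.path.join(pretrained_model_path, "Quality Control", "training_evaluation.csv")
    if os.path.exists(history_csv):
        history = pd.read_csv(history_csv)
        if "learning rate" in history.columns:
            if weights_choice == "last":
                # Learning rate at the final epoch of the previous run
                learning_rate = float(history["learning rate"].iloc[-1])
            else:
                # Learning rate at the epoch with the lowest validation loss
                learning_rate = float(history.loc[history["val_loss"].idxmin(), "learning rate"])
    return weights_file, learning_rate

# The returned weights file is then handed to the training routine, which loads it
# with Keras' model.load_weights(...) before training resumes.
```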
"]},{"cell_type":"code","metadata":{"id":"oHL5g0w8LqR0","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_'+Weights_choice+'.hdf5 pretrained model does not exist'+bcolors.NORMAL)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," 
lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead.'+bcolors.NORMAL)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+bcolors.NORMAL)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print('No pretrained network will be used.')\n"," h5_file_path = ''\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OADNcie-LHxA"},"source":["## **4.4. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"id":"qDgMu_mAK8US","cellView":"form"},"source":["#@markdown ##Start training\n","\n","# Start the clock to measure how long it takes\n","start = time.time()\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(Save_path):\n"," shutil.rmtree(Save_path)\n","\n","# Create the model folder!\n","os.makedirs(Save_path)\n","\n","# Export pdf summary \n","pdf_export(raw_data = load_raw_data, pretrained_model = Use_pretrained_model)\n","\n","# Let's go !\n","train_model(patches, heatmaps, Save_path, \n"," steps_per_epoch=number_of_steps, epochs=number_of_epochs, batch_size=batch_size,\n"," upsampling_factor = upsampling_factor,\n"," validation_split = percentage_validation,\n"," initial_learning_rate = initial_learning_rate, \n"," pretrained_model_path = h5_file_path,\n"," L2_weighting_factor = L2_weighting_factor)\n","\n","# # Show info about the GPU memory useage\n","# !nvidia-smi\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, 
\"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# export pdf after training to update the existing document\n","pdf_export(trained = True, raw_data = load_raw_data, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4N7-ShZpLhwr"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"JDRsm7uKoBa-","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after the parameter `model_name` (see section 4.1). Provide the name of this folder as `QC_model_path` . \n","\n","QC_model_path = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(QC_model_path):\n"," print(\"The \"+os.path.basename(QC_model_path)+\" model will be evaluated\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Gw7KaHZUoHC4"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. 
In this case, the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"qUc-JMOcoGNZ","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(QC_model_path,'Quality Control/training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," if row:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(QC_model_path,'Quality Control/lossCurvePlots.png'), bbox_inches='tight', pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"32eNQjFioQkY"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps, as well as calculate total SSIM, NRMSE and PSNR metrics for all the images provided in the \"QC_image_folder\", using the corresponding localization data contained in \"QC_loc_folder\".\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"dhlTnxC5lUZy","cellView":"form"},"source":["\n","# ------------------------ User input ------------------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","QC_image_folder = \"\" #@param{type:\"string\"}\n","QC_loc_folder = \"\" #@param{type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value:\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","if get_pixel_size_from_file:\n"," pixel_size_INPUT = None\n","else:\n"," pixel_size_INPUT = pixel_size\n","\n","\n","# ------------------------ QC analysis loop over provided dataset ------------------------\n","\n","savePath = os.path.join(QC_model_path, 'Quality Control')\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(os.path.join(savePath, os.path.basename(QC_model_path)+\"_QC_metrics.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"WF v. GT mSSIM\", \"Prediction v. GT NRMSE\",\"WF v. GT NRMSE\", \"Prediction v. GT PSNR\", \"WF v. GT PSNR\"])\n","\n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvWF_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvWF_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvWF_list = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n"," for (imageFilename, locFilename) in zip(list_files(QC_image_folder, 'tif'), list_files(QC_loc_folder, 'csv')):\n"," print('--------------')\n"," print(imageFilename)\n"," print(locFilename)\n","\n"," # Get the prediction\n"," batchFramePredictionLocalization(QC_image_folder, imageFilename, QC_model_path, savePath, pixel_size = pixel_size_INPUT)\n","\n"," # test_model(QC_image_folder, imageFilename, QC_model_path, savePath, display=False);\n"," thisPrediction = io.imread(os.path.join(savePath, 'Predicted_'+imageFilename))\n"," thisWidefield = io.imread(os.path.join(savePath, 'Widefield_'+imageFilename))\n","\n"," Mhr = thisPrediction.shape[0]\n"," Nhr = thisPrediction.shape[1]\n","\n"," if pixel_size_INPUT == None:\n"," pixel_size, N, M = getPixelSizeTIFFmetadata(os.path.join(QC_image_folder,imageFilename))\n","\n"," upsampling_factor = int(Mhr/M)\n"," print('Upsampling factor: '+str(upsampling_factor))\n"," pixel_size_hr = pixel_size/upsampling_factor # in nm\n","\n"," # Load the localization file and display the first\n"," LocData = pd.read_csv(os.path.join(QC_loc_folder,locFilename), index_col=0)\n","\n"," x = np.array(list(LocData['x [nm]']))\n"," y = np.array(list(LocData['y [nm]']))\n"," locImage = FromLoc2Image_SimpleHistogram(x, y, image_size = (Mhr,Nhr), pixel_size = pixel_size_hr)\n","\n"," # Remove extension from filename\n"," imageFilename_no_extension = os.path.splitext(imageFilename)[0]\n","\n"," # io.imsave(os.path.join(savePath, 'GT_image_'+imageFilename), locImage)\n"," saveAsTIF(savePath, 'GT_image_'+imageFilename_no_extension, locImage, pixel_size_hr)\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm, test_prediction_norm = norm_minmse(locImage, thisPrediction, 
normalize_gt=True)\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm, test_wf_norm = norm_minmse(locImage, thisWidefield, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1., full=True)\n"," index_SSIM_GTvsWF, img_SSIM_GTvsWF = structural_similarity(test_GT_norm, test_wf_norm, data_range=1., full=True)\n","\n","\n"," # Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsPrediction_'+imageFilename),img_SSIM_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsPrediction_'+imageFilename_no_extension, img_SSIM_GTvsPrediction_32bit, pixel_size_hr)\n","\n","\n"," img_SSIM_GTvsWF_32bit = np.float32(img_SSIM_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'SSIM_GTvsWF_'+imageFilename),img_SSIM_GTvsWF_32bit)\n"," saveAsTIF(savePath,'SSIM_GTvsWF_'+imageFilename_no_extension, img_SSIM_GTvsWF_32bit, pixel_size_hr)\n","\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsWF = np.sqrt(np.square(test_GT_norm - test_wf_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsPrediction_'+imageFilename),img_RSE_GTvsPrediction_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsPrediction_'+imageFilename_no_extension, img_RSE_GTvsPrediction_32bit, pixel_size_hr)\n","\n"," img_RSE_GTvsWF_32bit = np.float32(img_RSE_GTvsWF)\n"," # io.imsave(os.path.join(savePath,'RSE_GTvsWF_'+imageFilename),img_RSE_GTvsWF_32bit)\n"," saveAsTIF(savePath,'RSE_GTvsWF_'+imageFilename_no_extension, img_RSE_GTvsWF_32bit, pixel_size_hr)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsWF = np.sqrt(np.mean(img_RSE_GTvsWF))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsWF = psnr(test_GT_norm,test_wf_norm,data_range=1.0)\n","\n"," writer.writerow([imageFilename,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsWF),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsWF),str(PSNR_GTvsPrediction), str(PSNR_GTvsWF)])\n","\n"," # Collect values to display in dataframe output\n"," file_name_list.append(imageFilename)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvWF_list.append(index_SSIM_GTvsWF)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvWF_list.append(NRMSE_GTvsWF)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvWF_list.append(PSNR_GTvsWF)\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = file_name_list)\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list\n","pdResults[\"Wide-field v. GT mSSIM\"] = mSSIM_GvWF_list\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list\n","pdResults[\"Wide-field v. GT NRMSE\"] = NRMSE_GvWF_list\n","pdResults[\"Prediction v. 
GT PSNR\"] = PSNR_GvP_list\n","pdResults[\"Wide-field v. GT PSNR\"] = PSNR_GvWF_list\n","\n","\n","# ------------------------ Display ------------------------\n","\n","print('--------------------------------------------')\n","@interact\n","def show_QC_results(file = list_files(QC_image_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,15))\n"," # Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = io.imread(os.path.join(savePath, 'GT_image_'+file))\n"," plt.imshow(img_GT, norm = simple_norm(img_GT, percent = 99.5))\n"," plt.title('Target',fontsize=15)\n","\n"," # Wide-field\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(savePath, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield',fontsize=15)\n","\n"," #Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(savePath, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Prediction',fontsize=15)\n","\n"," #Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n"," #SSIM between GT and Source\n"," plt.subplot(3,3,5)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsWF = io.imread(os.path.join(savePath, 'SSIM_GTvsWF_'+file))\n"," imSSIM_GTvsWF = plt.imshow(img_SSIM_GTvsWF, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imSSIM_GTvsWF,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Wide-field v. GT mSSIM\"],3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_SSIM_GTvsPrediction = io.imread(os.path.join(savePath, 'SSIM_GTvsPrediction_'+file))\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(pdResults.loc[file][\"Prediction v. 
GT mSSIM\"],3)),fontsize=14)\n","\n"," #Root Squared Error between GT and Source\n"," plt.subplot(3,3,8)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsWF = io.imread(os.path.join(savePath, 'RSE_GTvsWF_'+file))\n"," imRSE_GTvsWF = plt.imshow(img_RSE_GTvsWF, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsWF,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Widefield',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Wide-field v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Wide-field v. GT PSNR\"],3)),fontsize=14)\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n"," #Root Squared Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n"," img_RSE_GTvsPrediction = io.imread(os.path.join(savePath, 'RSE_GTvsPrediction_'+file))\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(pdResults.loc[file][\"Prediction v. GT NRMSE\"],3))+', PSNR: '+str(round(pdResults.loc[file][\"Prediction v. GT PSNR\"],3)),fontsize=14)\n"," plt.savefig(QC_model_path+'/Quality Control/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n","print('--------------------------------------------')\n","pdResults.head()\n","\n","# Export pdf wth summary of QC results\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yTRou0izLjhd"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"eAf8aBDmWTx7"},"source":["## **6.1 Generate image prediction and localizations from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the found localizations csv.\n","\n","**`batch_size`:** This paramter determines how many frames are processed by any single pass on the GPU. 
A higher `batch_size` will make the prediction faster but will use more GPU memory. If an OutOfMemory (OOM) error occurs, decrease the `batch_size`. **DEFAULT: 4**\n","\n","**`threshold`:** This parameter determines the threshold for local maxima finding. The value is expected to reside in the range **[0,1]**. A higher `threshold` will result in fewer localizations. **DEFAULT: 0.1**\n","\n","**`neighborhood_size`:** This parameter determines the size of the neighborhood within which the prediction needs to be a local maximum, in recovery pixels (CCD pixel/upsampling_factor). A high `neighborhood_size` will make the prediction slower and potentially discard nearby localizations. **DEFAULT: 3**\n","\n","**`use_local_average`:** This parameter determines whether to locally average the prediction in a 3x3 neighborhood to get the final localizations. If set to **True**, it will make inference slightly slower depending on the size of the FOV. **DEFAULT: True**\n"]},{"cell_type":"code","metadata":{"id":"7qn06T_A0lxf","cellView":"form"},"source":["\n","# ------------------------------- User input -------------------------------\n","#@markdown ### Data parameters\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","#@markdown Get pixel size from file?\n","get_pixel_size_from_file = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, use this value (in nm):\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","#@markdown ### Model parameters\n","#@markdown Do you want to use the model you just trained?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","#@markdown Otherwise, please provide the path to the model folder below\n","prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Prediction parameters\n","batch_size = 4#@param {type:\"integer\"}\n","\n","#@markdown ### Post processing parameters\n","threshold = 0.1#@param {type:\"number\"}\n","neighborhood_size = 3#@param {type:\"integer\"}\n","#@markdown Do you want to locally average the model output with the CoG estimator?\n","use_local_average = True #@param {type:\"boolean\"}\n","\n","\n","if get_pixel_size_from_file:\n"," pixel_size = None\n","\n","if (Use_the_current_trained_model): \n"," prediction_model_path = os.path.join(model_path, model_name)\n","\n","if os.path.exists(prediction_model_path):\n"," print(\"The \"+os.path.basename(prediction_model_path)+\" model will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! 
WARNING: The chosen model does not exist !!'+bcolors.NORMAL)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","# inform user whether local averaging is being used\n","if use_local_average == True: \n"," print('Using local averaging')\n","\n","if not os.path.exists(Result_folder):\n"," print('Result folder was created.')\n"," os.makedirs(Result_folder)\n","\n","\n","# ------------------------------- Run predictions -------------------------------\n","\n","start = time.time()\n","#%% This script tests the trained fully convolutional network based on the \n","# saved training weights, and normalization created using train_model.\n","\n","if os.path.isdir(Data_folder): \n"," for filename in list_files(Data_folder, 'tif'):\n"," # run the testing/reconstruction process\n"," print(\"------------------------------------\")\n"," print(\"Running prediction on: \"+ filename)\n"," batchFramePredictionLocalization(Data_folder, filename, prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average,\n"," pixel_size = pixel_size)\n","\n","elif os.path.isfile(Data_folder):\n"," batchFramePredictionLocalization(os.path.dirname(Data_folder), os.path.basename(Data_folder), prediction_model_path, Result_folder, \n"," batch_size, \n"," threshold, \n"," neighborhood_size, \n"," use_local_average, \n"," pixel_size = pixel_size)\n","\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------------------- Interactive display -------------------------------\n","\n","print('--------------------------------------------------------------------')\n","print('---------------------------- Previews ------------------------------')\n","print('--------------------------------------------------------------------')\n","\n","if os.path.isdir(Data_folder): \n"," @interact\n"," def show_QC_results(file = list_files(Data_folder, 'tif')):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+file))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+file))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," plt.title('Predicted',fontsize=15)\n","\n","if os.path.isfile(Data_folder):\n","\n"," plt.figure(figsize=(15,7.5))\n"," # Wide-field\n"," plt.subplot(1,2,1)\n"," plt.axis('off')\n"," img_Source = io.imread(os.path.join(Result_folder, 'Widefield_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Source, norm = simple_norm(img_Source, percent = 99.5))\n"," plt.title('Widefield', fontsize=15)\n"," # Prediction\n"," plt.subplot(1,2,2)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(Result_folder, 'Predicted_'+os.path.basename(Data_folder)))\n"," plt.imshow(img_Prediction, norm = simple_norm(img_Prediction, percent = 99.5))\n"," 
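# --- Illustrative sketch added by the editor (not part of the original notebook) ---
# How the post-processing parameters of section 6.1 interact: keep pixels that are
# both above `threshold` and the maximum of their neighborhood_size x neighborhood_size
# surroundings; a centre-of-gravity refinement can then be applied when
# use_local_average is True. This is only a simplified stand-in for the
# batchFramePredictionLocalization helper used above, on hypothetical data.
import numpy as np
from scipy.ndimage import maximum_filter

def find_local_maxima(pred, threshold=0.1, neighborhood_size=3):
    is_peak = (pred == maximum_filter(pred, size=neighborhood_size)) & (pred > threshold)
    return np.argwhere(is_peak)   # (row, col) positions in recovery pixels

demo = np.zeros((16, 16)); demo[5, 7] = 0.8; demo[5, 8] = 0.3; demo[12, 2] = 0.05
print(find_local_maxima(demo))    # -> [[5 7]] ; the 0.05 peak falls below the threshold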
plt.title('Predicted',fontsize=15)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZekzexaPmzFZ"},"source":["## **6.2 Drift correction**\n","---\n","\n","The visualization above is the raw output of the network, displayed at the `upsampling_factor` chosen during model training. The display is a preview without any drift correction applied. This section performs drift correction using cross-correlation between time bins to estimate the drift.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the image reconstructions used for the drift correction estimation (in **nm**). A smaller pixel size will be more precise but will take longer to compute. **DEFAULT: 20**\n","\n","**`number_of_bins`:** This parameter defines how many temporal bins are used across the full dataset. All localizations in each bin are used to build an image. This image is used to find the drift with respect to the image obtained from the very first bin. A typical value would correspond to about 500 frames per bin. **DEFAULT: Total number of frames / 500**\n","\n","**`polynomial_fit_degree`:** The drift obtained for each temporal bin needs to be interpolated to every single frame. This is performed by a polynomial fit, the degree of which is defined here. **DEFAULT: 4**\n","\n"," The drift-corrected localization data is automatically saved in the `save_path` folder."]},{"cell_type":"code","metadata":{"id":"hYtP_vh6mzUP","cellView":"form"},"source":["# @markdown ##Data parameters\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. 
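# --- Illustrative sketch added by the editor (not part of the original notebook) ---
# The drift estimation described above boils down to: build one reconstruction per
# temporal bin, cross-correlate it with the reconstruction of the first bin, and read
# the drift from the position of the correlation peak. A minimal, self-contained
# example on synthetic data (the notebook itself uses its FromLoc2Image_SimpleHistogram,
# fftconvolve and subPixelMaxLocalization helpers for this step):
import numpy as np

ref = np.zeros((64, 64)); ref[30, 30] = 1.0             # reconstruction of the first bin
drifted = np.roll(np.roll(ref, 3, axis=0), -2, axis=1)  # later bin, drifted by (3, -2) px
xc = np.real(np.fft.ifft2(np.conj(np.fft.fft2(ref)) * np.fft.fft2(drifted)))
dy, dx = np.unravel_index(np.argmax(xc), xc.shape)      # correlation peak = integer drift
dy = dy - xc.shape[0] if dy > xc.shape[0] // 2 else dy  # wrap to signed shifts
dx = dx - xc.shape[1] if dx > xc.shape[1] // 2 else dx
print(dy, dx)  # -> 3 -2 ; multiply by visualization_pixel_size to express the drift in nm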
Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100 #@param {type:\"number\"}\n","\n","# @markdown ##Drift correction parameters\n","visualization_pixel_size = 20#@param {type:\"number\"}\n","number_of_bins = 50#@param {type:\"integer\"}\n","polynomial_fit_degree = 4#@param {type:\"integer\"}\n","\n","# @markdown ##Saving parameters\n","save_path = '' #@param {type:\"string\"}\n","\n","\n","# Let's go !\n","start = time.time()\n","\n","# Get info from the raw file if selected\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","# Read the localizations in\n","LocData = pd.read_csv(Loc_file_path)\n","\n","# Calculate a few variables \n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","n_locs = len(LocData.index)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(n_locs))\n","\n","blocksize = math.ceil(nFrames/number_of_bins)\n","print('Number of frames per block: '+str(blocksize))\n","\n","blockDataFrame = LocData[(LocData['frame'] < blocksize)].copy()\n","xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n","# Preparing the Reference image\n","photon_array = np.ones(yc_array.shape[0])\n","sigma_array = np.ones(yc_array.shape[0])\n","ImageRef = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImagesRef = np.rot90(ImageRef, k=2)\n","\n","xDrift = np.zeros(number_of_bins)\n","yDrift = np.zeros(number_of_bins)\n","\n","filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","with open(os.path.join(save_path, filename_no_extension+\"_DriftCorrectionData.csv\"), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"Block #\", \"x-drift [nm]\",\"y-drift [nm]\"])\n","\n"," for b in tqdm(range(number_of_bins)):\n","\n"," blockDataFrame = LocData[(LocData['frame'] >= (b*blocksize)) & (LocData['frame'] < ((b+1)*blocksize))].copy()\n"," xc_array = blockDataFrame['x [nm]'].to_numpy(dtype=np.float32)\n"," yc_array = blockDataFrame['y [nm]'].to_numpy(dtype=np.float32)\n","\n"," photon_array = np.ones(yc_array.shape[0])\n"," sigma_array = np.ones(yc_array.shape[0])\n"," ImageBlock = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n"," XC = fftconvolve(ImagesRef, ImageBlock, mode = 'same')\n"," yDrift[b], xDrift[b] = subPixelMaxLocalization(XC, method = 'CoM')\n","\n"," # saveAsTIF(save_path, 'ImageBlock'+str(b), ImageBlock, 
visualization_pixel_size)\n"," # saveAsTIF(save_path, 'XCBlock'+str(b), XC, visualization_pixel_size)\n"," writer.writerow([str(b), str((xDrift[b]-xDrift[0])*visualization_pixel_size), str((yDrift[b]-yDrift[0])*visualization_pixel_size)])\n","\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for drift estimation\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","print('Fitting drift data...')\n","bin_number = np.arange(number_of_bins)*blocksize + blocksize/2\n","xDrift = (xDrift-xDrift[0])*visualization_pixel_size\n","yDrift = (yDrift-yDrift[0])*visualization_pixel_size\n","\n","xDriftCoeff = np.polyfit(bin_number, xDrift, polynomial_fit_degree)\n","yDriftCoeff = np.polyfit(bin_number, yDrift, polynomial_fit_degree)\n","\n","xDriftFit = np.poly1d(xDriftCoeff)\n","yDriftFit = np.poly1d(yDriftCoeff)\n","bins = np.arange(nFrames)\n","xDriftInterpolated = xDriftFit(bins)\n","yDriftInterpolated = yDriftFit(bins)\n","\n","\n","# ------------------ Displaying the image results ------------------\n","\n","plt.figure(figsize=(15,10))\n","plt.plot(bin_number,xDrift, 'r+', label='x-drift')\n","plt.plot(bin_number,yDrift, 'b+', label='y-drift')\n","plt.plot(bins,xDriftInterpolated, 'r-', label='x-drift (fit)')\n","plt.plot(bins,yDriftInterpolated, 'b-', label='y-drift (fit)')\n","plt.title('Cross-correlation estimated drift')\n","plt.ylabel('Drift [nm]')\n","plt.xlabel('Bin number')\n","plt.legend();\n","\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\", hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","\n","# ------------------ Actual drift correction -------------------\n","\n","print('Correcting localization data...')\n","xc_array = LocData['x [nm]'].to_numpy(dtype=np.float32)\n","yc_array = LocData['y [nm]'].to_numpy(dtype=np.float32)\n","frames = LocData['frame'].to_numpy(dtype=np.int32)\n","\n","\n","xc_array_Corr, yc_array_Corr = correctDriftLocalization(xc_array, yc_array, frames, xDriftInterpolated, yDriftInterpolated)\n","ImageRaw = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","ImageCorr = FromLoc2Image_SimpleHistogram(xc_array_Corr, yc_array_Corr, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","\n","# ------------------ Displaying the image results ------------------\n","plt.figure(figsize=(15,7.5))\n","# Raw\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(ImageRaw, norm = simple_norm(ImageRaw, percent = 99.5))\n","plt.title('Raw', fontsize=15);\n","# Corrected\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(ImageCorr, norm = simple_norm(ImageCorr, percent = 99.5))\n","plt.title('Corrected',fontsize=15);\n","\n","\n","# ------------------ Table with info -------------------\n","driftCorrectedLocData = pd.DataFrame()\n","driftCorrectedLocData['frame'] = frames\n","driftCorrectedLocData['x [nm]'] = xc_array_Corr\n","driftCorrectedLocData['y [nm]'] = yc_array_Corr\n","driftCorrectedLocData['confidence [a.u]'] = LocData['confidence [a.u]']\n","\n","driftCorrectedLocData.to_csv(os.path.join(save_path, filename_no_extension+'_DriftCorrected.csv'))\n","print('-------------------------------')\n","print('Corrected localizations 
saved.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"mzOuc-V7rB-r"},"source":["## **6.3 Visualization of the localizations**\n","---\n","\n","\n","The visualization in section 6.1 is the raw output of the network and displayed at the `upsampling_factor` chosen during model training. This section performs visualization of the result by plotting the localizations as a simple histogram.\n","\n","**`Loc_file_path`:** is the path to the localization file to use for visualization.\n","\n","**`original_image_path`:** is the path to the original image. This only serves to extract the original image size and pixel size to shape the visualization properly.\n","\n","**`visualization_pixel_size`:** This parameter corresponds to the pixel size to use for the final image reconstruction (in **nm**). **DEFAULT: 10**\n","\n","**`visualization_mode`:** This parameter defines what visualization method is used to visualize the final image. NOTES: The Integrated Gaussian can be quite slow. **DEFAULT: Simple histogram.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"876yIXnqq-nW","cellView":"form"},"source":["# @markdown ##Data parameters\n","Use_current_drift_corrected_localizations = True #@param {type:\"boolean\"}\n","# @markdown Otherwise provide a localization file path\n","Loc_file_path = \"\" #@param {type:\"string\"}\n","# @markdown Provide information about original data. Get the info automatically from the raw data?\n","Get_info_from_file = True #@param {type:\"boolean\"}\n","# Loc_file_path = \"/content/gdrive/My Drive/Colab notebooks testing/DeepSTORM/Glia data from CL/Results from prediction/20200615-M6 with CoM localizations/Localizations_glia_actin_2D - 1-500fr_avg.csv\" #@param {type:\"string\"}\n","original_image_path = \"\" #@param {type:\"string\"}\n","# @markdown Otherwise, please provide image width, height (in pixels) and pixel size (in nm)\n","image_width = 256#@param {type:\"integer\"}\n","image_height = 256#@param {type:\"integer\"}\n","pixel_size = 100#@param {type:\"number\"}\n","\n","# @markdown ##Visualization parameters\n","visualization_pixel_size = 10#@param {type:\"number\"}\n","visualization_mode = \"Simple histogram\" #@param [\"Simple histogram\", \"Integrated Gaussian (SLOW!)\"]\n","\n","if not Use_current_drift_corrected_localizations:\n"," filename_no_extension = os.path.splitext(os.path.basename(Loc_file_path))[0]\n","\n","\n","if Get_info_from_file:\n"," pixel_size, image_width, image_height = getPixelSizeTIFFmetadata(original_image_path, display=True)\n","\n","if Use_current_drift_corrected_localizations:\n"," LocData = driftCorrectedLocData\n","else:\n"," LocData = pd.read_csv(Loc_file_path)\n","\n","Mhr = int(math.ceil(image_height*pixel_size/visualization_pixel_size))\n","Nhr = int(math.ceil(image_width*pixel_size/visualization_pixel_size))\n","\n","\n","nFrames = max(LocData['frame'])\n","x_max = max(LocData['x [nm]'])\n","y_max = max(LocData['y [nm]'])\n","image_size = (Mhr, Nhr)\n","\n","print('Image size: '+str(image_size))\n","print('Number of frames in data: '+str(nFrames))\n","print('Number of localizations in data: '+str(len(LocData.index)))\n","\n","xc_array = LocData['x [nm]'].to_numpy()\n","yc_array = LocData['y [nm]'].to_numpy()\n","if (visualization_mode == 'Simple histogram'):\n"," locImage = FromLoc2Image_SimpleHistogram(xc_array, yc_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","elif (visualization_mode == 'Shifted histogram'):\n"," print(bcolors.WARNING+'Method not 
implemented yet!'+bcolors.NORMAL)\n"," locImage = np.zeros(image_size)\n","elif (visualization_mode == 'Integrated Gaussian (SLOW!)'):\n"," photon_array = np.ones(xc_array.shape)\n"," sigma_array = np.ones(xc_array.shape)\n"," locImage = FromLoc2Image_Erf(xc_array, yc_array, photon_array, sigma_array, image_size = image_size, pixel_size = visualization_pixel_size)\n","\n","print('--------------------------------------------------------------------')\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","minutes, seconds = divmod(dt, 60) \n","hours, minutes = divmod(minutes, 60) \n","print(\"Time elapsed:\",hours, \"hour(s)\",minutes,\"min(s)\",round(seconds),\"sec(s)\")\n","\n","# Display\n","plt.figure(figsize=(20,10))\n","plt.axis('off')\n","# plt.imshow(locImage, cmap='gray');\n","plt.imshow(locImage, norm = simple_norm(locImage, percent = 99.5));\n","\n","\n","LocData.head()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"PdOhWwMn1zIT","cellView":"form"},"source":["# @markdown ---\n","\n","# @markdown #Play this cell to save the visualization\n","\n","# @markdown ####Please select a path to the folder where to save the visualization.\n","save_path = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(save_path):\n"," os.makedirs(save_path)\n"," print('Folder created.')\n","\n","saveAsTIF(save_path, filename_no_extension+'_Visualization', locImage, visualization_pixel_size)\n","print('Image saved.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1EszIF4Dkz_n"},"source":["## **6.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"0BvykD0YIk89"},"source":["# **7. Version log**\n","\n","---\n","**v1.13**: \n","* The section 1 and 2 are now swapped for better export of *requirements.txt*. 
\n","* This version also now includes built-in version check and the version log that you're reading now.\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"UgN-NooKk3nV"},"source":["\n","#**Thank you for using Deep-STORM 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv b/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv index 809bdcb8..d3456a90 100644 --- a/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv +++ b/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv @@ -1 +1 @@ -1.12 +1.13 diff --git a/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb index 69d53ef4..6a37f4ac 100644 --- a/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Noise2Void_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Noise2Void_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hMjEc-Ex7j-jeYGclaPw2x3OgbkeC6Bl","timestamp":1610626439596},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **Noise2Void (2D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 2D dataset. If you are interested in 3D dataset, you should use the Noise2Void 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: /~https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. 
Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images (Quality control dataset). These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - Results\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"L1Nwo9k5kXCm"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"2JydTyW1kafd"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"C30tLtARkQve","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **1.2 Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. 
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"markdown","metadata":{"id":"6sReQBH5VKlw"},"source":["## **2.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["#@markdown ##Install Noise2Void and dependencies\n","\n","\n","!pip uninstall -y tensorflow\n","!pip install tensorflow==2.4.1\n","\n","import tensorflow \n","\n","!pip uninstall -y keras-nightly\n","\n","!pip3 install h5py==2.10.0\n","\n","!pip install n2v\n","\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","\n","!pip install q keras==2.3.1\n","\n","#Force session restart\n","exit(0)\n","print('--------')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JF4AHIp2VOWo"},"source":["## **2.2. Restart your runtime**\n","---\n","\n","\n","\n","** Your Runtime has automatically restarted. This is normal.**\n","\n"]},{"cell_type":"markdown","metadata":{"id":"P3b6lUliVRdr"},"source":["## **2.3. 
Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"WFKHpNCrVUPc","cellView":"form"},"source":["#@markdown ##Load key dependencies\n","\n","Notebook_version = ['1.12.3']\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#import tensorflow\n","import tensorflow \n","print('TensorFlow version:')\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","%load_ext memory_profiler\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","\n","# Check if this is the latest version of the 
notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 2D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[0]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(Xdata.shape[0])+' image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(Xdata.shape[0])+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(26, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by default.'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
patch_size{1}
batch_size{2}
number_of_steps{3}
percentage_validation{4}
initial_learning_rate{5}
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," # pdf.set_font('')\n"," # pdf.set_font('Arial', size = 10, style = 'B')\n"," # pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," # pdf.set_font('')\n"," # pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training Image', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_N2V2D.png').shape\n"," pdf.image('/content/TrainingDataExample_N2V2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 
2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","\n"," #Make a pdf summary of the QC results\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 100**\n"," \n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be between 64 and the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. 
**Default value: 10**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# create DataGenerator-object.\n","\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","#compatibility to easily change the name of the parameters\n","training_images = Training_source \n","imgs = datagen.load_imgs_from_directory(directory = Training_source)\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = model_path+'/'+model_name+'/'\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels)\n","patch_size = 64#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_N2V2D.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","cellView":"form"},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. 
This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: 
'+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# split patches from the training images\n","Xdata = datagen.generate_patches_from_list(imgs, shape=(patch_size,patch_size), augment=Use_Data_augmentation)\n","shape_of_Xdata = Xdata.shape\n","# create a threshold (10 % patches for the validation)\n","threshold = int(shape_of_Xdata[0]*(percentage_validation/100))\n","# split the patches into training patches and validation patches\n","X = Xdata[threshold:]\n","X_val = Xdata[:threshold]\n","print(Xdata.shape[0],\"patches created.\")\n","print(threshold,\"patch images for validation (\",percentage_validation,\"%).\")\n","print(Xdata.shape[0]-threshold,\"patch images for training.\")\n","%memit\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," train_loss='mse', batch_norm=True, train_batch_size=batch_size, n2v_perc_pix=0.198, \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," \n","# create network model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","print(\"Setup done.\")\n","print(config)\n","\n","\n","# creates a plot and shows one training patch and one validation patch.\n","plt.figure(figsize=(16,87))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","pdf_export(pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this \n","point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSB Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n"]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","\n","history = model.train(X, X_val)\n","print(\"Training done.\")\n","%memit\n","\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='Noise2Void', \n"," description='Noise2Void 2D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='YX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","pdf_export(trained = True, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! 
WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. 
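With config=None, N2V reads the configuration (config.json) that was saved in\n","# the model folder during training and loads the stored weights before prediction.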
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX', n_tiles=(2,1))\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target 
(Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT)\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source)\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
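The cell below walks through every .tif file in `Data_folder` and applies the model to each one; as an isolated illustration (the file paths are placeholders and `model` is the trained network loaded in this section), a single 2D image would be processed like this:\n","\n","```python\n","from tifffile import imread\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","img = imread('/content/example_noisy_image.tif')     # placeholder input path\n","pred = model.predict(img, axes='YX', n_tiles=(2,1))  # tiled prediction of a single 2D image\n","save_tiff_imagej_compatible('/content/example_denoised.tif', pred, axes='YX')\n","```\n","\n","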
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks"]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","cellView":"form"},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. 
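Setting config to None below makes N2V rebuild the network from the config.json stored in\n","#the model folder and load its trained weights; model.predict later uses n_tiles to process\n","#each image in tiles so that large images fit into GPU memory.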
\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n"," # r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n","\n","# Loop through the files\n"," for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='YX', n_tiles=(2,1))\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='YX') \n","\n"," print(\"Images saved into folder:\", Result_folder)\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," timelapse = imread(os.path.join(r, file))\n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n","\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," prediction_stack[t] = model.predict(img_t, axes='YX', n_tiles=(2,1))\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," imsave(os.path.join(outputdir, base_filename), prediction_stack_32) \n"," \n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"67_8rEKp8C-z"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"n-stU-f08Cae"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","\n","\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","if Data_type == 1 :\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x, interpolation='nearest')\n"," plt.title('Input')\n"," plt.axis('off');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y, interpolation='nearest')\n"," plt.title('Predicted output')\n"," plt.axis('off');\n","\n","if Data_type == 2 :\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[1], interpolation='nearest')\n"," plt.title('Input')\n"," plt.axis('off');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[1], interpolation='nearest')\n"," plt.title('Predicted output')\n"," plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using Noise2Void 2D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"Noise2Void_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hMjEc-Ex7j-jeYGclaPw2x3OgbkeC6Bl","timestamp":1610626439596},{"file_id":"1_W4q9V1ExGFldTUBvGK91E0LG5QMc7K6","timestamp":1602523405636},{"file_id":"1t9a-44km730bI7F4I08-6Xh7wEZuL98p","timestamp":1591013189418},{"file_id":"11TigzvLl4FSSwFHUNwLzZKI2IAix4Nmu","timestamp":1586415689249},{"file_id":"1_dSnxUg_qtNWjrPc7D6RWDWlCanEL4Ve","timestamp":1585153449937},{"file_id":"1bKo8jYVZPPgXPa_-Gdu1KhDnNN4vYfLx","timestamp":1583200150464}],"collapsed_sections":[],"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.4"}},"cells":[{"cell_type":"markdown","metadata":{"id":"V9zNGvape2-I"},"source":["# **Noise2Void (2D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 2D dataset. If you are interested in 3D dataset, you should use the Noise2Void 3D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: /~https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"vNMDQHm0Ah-Z"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. 
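If you want to quickly check that your training images have the expected layout before you start, the minimal sketch below (the folder path is only an example) prints the shape and data type of every .tif file; for this notebook each file should be a single 2D image (y,x):\n","\n","```python\n","from glob import glob\n","from tifffile import imread\n","\n","Training_source = '/content/gdrive/MyDrive/Training_dataset'  # example path, adapt to your data\n","for path in sorted(glob(Training_source + '/*.tif')):\n","    img = imread(path)\n","    print(path, img.shape, img.dtype)\n","```\n","\n","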
Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images (Quality control dataset). These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - Results\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"markdown","metadata":{"id":"6sReQBH5VKlw"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"3u2mXn3XsWzd","cellView":"form"},"source":["#@markdown ##Install Noise2Void and dependencies\n","\n","\n","!pip uninstall -y tensorflow\n","!pip install tensorflow==2.4.1\n","\n","import tensorflow \n","\n","!pip uninstall -y keras-nightly\n","\n","!pip3 install h5py==2.10.0\n","\n","!pip install n2v\n","\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","\n","!pip install q keras==2.3.1\n","\n","#Force session restart\n","exit(0)\n","print('--------')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JF4AHIp2VOWo"},"source":["## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"P3b6lUliVRdr"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"WFKHpNCrVUPc","cellView":"form"},"source":["#@markdown ##Load key dependencies\n","\n","Notebook_version = '1.13'\n","Network = 'Noise2Void (2D)'\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#import tensorflow\n","import tensorflow \n","print('TensorFlow version:')\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# Here we install all libraries and other depencies to run the notebook.\n","\n","# ------- Variable specific to N2V -------\n","\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","%load_ext memory_profiler\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow 
warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","#PDF export\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 2D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[0]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(Xdata.shape[0])+' image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(Xdata.shape[0])+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(26, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by default.'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
patch_size{1}
batch_size{2}
number_of_steps{3}
percentage_validation{4}
initial_learning_rate{5}
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," # pdf.set_font('')\n"," # pdf.set_font('Arial', size = 10, style = 'B')\n"," # pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," # pdf.set_font('')\n"," # pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training Image', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_N2V2D.png').shape\n"," pdf.image('/content/TrainingDataExample_N2V2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 
2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","\n"," #Make a pdf summary of the QC results\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"L1Nwo9k5kXCm"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"2JydTyW1kafd"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"C30tLtARkQve","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-oqBTeLaImnU"},"source":["## **2.2 Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. 
Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xqF0lB-ShqQ4"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
"]},{"cell_type":"markdown","metadata":{"id":"Fw0kkTU6CsU4"},"source":["# **3. Select your parameters and paths**\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"WzYAA-MuaYrT"},"source":["## **3.1. Setting main training parameters**\n","---\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** These is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 4.3.). **Default value: 100**\n"," \n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be between 64 and the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# create DataGenerator-object.\n","\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training image(s): \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","#compatibility to easily change the name of the parameters\n","training_images = Training_source \n","imgs = datagen.load_imgs_from_directory(directory = Training_source)\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = model_path+'/'+model_name+'/'\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels)\n","patch_size = 64#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n"," \n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","# This will open a randomly chosen dataset input image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","#Here we display one image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_N2V2D.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGcl7WGP4WHt"},"source":["## **3.2. Data augmentation**\n","---"]},{"cell_type":"markdown","metadata":{"id":"5Lio8hpZ4PJ1"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," **By default data augmentation is enabled. Disable this option is you run out of RAM during the training**.\n"," "]},{"cell_type":"code","metadata":{"id":"htqjkJWt5J_8","cellView":"form"},"source":["#Data augmentation\n","\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bQDuybvyadKU"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 2D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. 
This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"8vPkzEBNamE4","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: 
'+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"rQndJj70FzfL"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"tGW2iaU6X5zi"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"WMJnGJpCMa4y","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# split patches from the training images\n","Xdata = datagen.generate_patches_from_list(imgs, shape=(patch_size,patch_size), augment=Use_Data_augmentation)\n","shape_of_Xdata = Xdata.shape\n","# create a threshold (10 % patches for the validation)\n","threshold = int(shape_of_Xdata[0]*(percentage_validation/100))\n","# split the patches into training patches and validation patches\n","X = Xdata[threshold:]\n","X_val = Xdata[:threshold]\n","print(Xdata.shape[0],\"patches created.\")\n","print(threshold,\"patch images for validation (\",percentage_validation,\"%).\")\n","print(Xdata.shape[0]-threshold,\"patch images for training.\")\n","%memit\n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size)+1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# create a Config object\n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps, train_epochs=number_of_epochs, \n"," train_loss='mse', batch_norm=True, train_batch_size=batch_size, n2v_perc_pix=0.198, \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","# Let's look at the parameters stored in the config-object.\n","vars(config)\n"," \n"," \n","# create network model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","print(\"Setup done.\")\n","print(config)\n","\n","\n","# creates a plot and shows one training patch and one validation patch.\n","plt.figure(figsize=(16,87))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,...,0], cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","pdf_export(pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wQPz0F6JlvJR"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this \n","point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSB Fiji plugin (Run your Network). You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n"]},{"cell_type":"code","metadata":{"id":"j_Qm5JBmlvJg","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","\n","history = model.train(X, X_val)\n","print(\"Training done.\")\n","%memit\n","\n","\n","print(\"Training, done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='Noise2Void', \n"," description='Noise2Void 2D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='YX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBdeep Fiji plugin\")\n","\n","pdf_export(trained = True, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QYuIOWQ3imuU"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"zazOZ3wDx0zQ","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," \n"," print(bcolors.WARNING + '!! 
WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"yDY9dtzdUTLh"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","It is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact noise patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"biT9FI9Ri77_"},"source":["## **5.2. 
Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"nAs4Wni7VYbq","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Activate the pretrained model. 
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," predicted = model_training.predict(img, axes='YX', n_tiles=(2,1))\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(filename, predicted)\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target 
(Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT)\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source)\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"69aJVFfsqXbY"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"tcPNRq1TrMPB"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks"]},{"cell_type":"code","metadata":{"id":"Am2JSmpC0frj","cellView":"form"},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING +'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Activate the pretrained model. 
\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","\n"," # r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","if Data_type == 1 :\n"," print(\"Single images are now beeing predicted\")\n","\n","# Loop through the files\n"," for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='YX', n_tiles=(2,1))\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='YX') \n","\n"," print(\"Images saved into folder:\", Result_folder)\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," timelapse = imread(os.path.join(r, file))\n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n","\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," prediction_stack[t] = model.predict(img_t, axes='YX', n_tiles=(2,1))\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," imsave(os.path.join(outputdir, base_filename), prediction_stack_32) \n"," \n"," \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"67_8rEKp8C-z"},"source":["## **6.2. Assess predicted output**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"n-stU-f08Cae"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","\n","\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","if Data_type == 1 :\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x, interpolation='nearest')\n"," plt.title('Input')\n"," plt.axis('off');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y, interpolation='nearest')\n"," plt.title('Predicted output')\n"," plt.axis('off');\n","\n","if Data_type == 2 :\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(x[1], interpolation='nearest')\n"," plt.title('Input')\n"," plt.axis('off');\n"," plt.subplot(1,2,2)\n"," plt.imshow(y[1], interpolation='nearest')\n"," plt.title('Predicted output')\n"," plt.axis('off');\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"MhqdcAboiZU4"},"source":["# **7. 
Version log**\n","---\n","**v1.13**: \n","\n","* N2V now uses tensorflow 2.4.\n","* This version now includes an automatic restart allowing to set the h5py library to v2.10.\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","This version also now includes built-in version check and the version log that \n","\n","* This version also now includes built-in version check and the version log that you're reading now."]},{"cell_type":"markdown","metadata":{"id":"u4pcBe8Z3T2J"},"source":["#**Thank you for using Noise2Void 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb index 5a019f2f..fdd9dadd 100644 --- a/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Noise2Void_3D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Noise2Void_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **Noise2Void (3D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 3D dataset. If you are interested in 2D dataset, you should use the Noise2Void 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: /~https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. 
Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - **Results**\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. 
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"markdown","metadata":{"id":"kXu3EG6fpA00"},"source":["## **2.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"yGkQrmBYpEFq","cellView":"form"},"source":["\n","#@markdown ##Install Noise2Void and dependencies\n","\n","\n","!pip uninstall -y tensorflow\n","!pip install tensorflow==2.4.1\n","\n","import tensorflow \n","\n","!pip uninstall -y keras-nightly\n","# !pip install q keras==2.2.5\n","!pip install q keras==2.3.1\n","\n","\n","!pip3 install h5py==2.10.0\n","!pip install n2v\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","\n","#Force session restart\n","exit(0)\n","\n","print('--------')\n","print('TensorFlow version:')\n","print(tensorflow.__version__)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"FpR7Cz6fpGqa"},"source":["## **2.2. Restart your runtime**\n","---\n","\n","\n","\n","** Your Runtime has automatically restarted. This is normal.**\n","\n"]},{"cell_type":"markdown","metadata":{"id":"kGskPwVSpJaO"},"source":["## **2.3. 
Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["#@markdown ##Load key dependencies\n","\n","Notebook_version = ['1.12.3']\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","\n","import tensorflow\n","print('TensorFlow version:')\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","%load_ext memory_profiler\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = 
pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 3D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[0]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(len(patches))+' image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(len(patches))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(26, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by default.'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
Parameter | Value
number_of_epochs | {0}
patch_size | {1}
batch_size | {2}
number_of_steps | {3}
percentage_validation | {4}
initial_learning_rate | {5}
\n"," \"\"\".format(number_of_epochs,str(patch_height)+'x'+str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," # pdf.set_font('')\n"," # pdf.set_font('Arial', size = 10, style = 'B')\n"," # pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," # pdf.set_font('')\n"," # pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Training Image', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_N2V3D.png').shape\n"," pdf.image('/content/TrainingDataExample_N2V3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 3D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0} | {1} | {2} | {3} | {4} | {5} | {6} | {7}
{0} | {1} | {2} | {3} | {4} | {5} | {6} | {7}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** This is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. 
**Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","# Create DataGenerator-object.\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","imgs = datagen.load_imgs_from_directory(directory = Training_source, dims='ZYX')\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of steps and epochs:\n","\n","number_of_epochs = 30#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 64#@param {type:\"number\"}\n","\n","patch_height = 4#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","#Load one randomly chosen training target file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING + \"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","\n","#Here we display a single z plane\n","\n","norm = simple_norm(x[mid_plane], percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_N2V3D.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," By default data augmentation is enabled. 
Disable this option is you run out of RAM during the training.\n"," "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not 
exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","#Disable some of the warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# Create batches from the training data.\n","patches = datagen.generate_patches_from_list(imgs, shape=(patch_height, patch_size, patch_size), augment=Use_Data_augmentation)\n","\n","# Patches are divited into training and validation patch set. This inhibits over-lapping of patches. \n","number_train_images =int(len(patches)*(percentage_validation/100))\n","X = patches[number_train_images:]\n","X_val = patches[:number_train_images]\n","\n","print(len(patches),\"patches created.\")\n","print(number_train_images,\"patch images for validation (\",percentage_validation,\"%).\")\n","print((len(patches)-number_train_images),\"patch images for training.\")\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size) + 1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# creates Congfig object. \n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps,train_epochs=number_of_epochs, train_loss='mse', batch_norm=True, \n"," train_batch_size=batch_size, n2v_perc_pix=0.198, n2v_patch_shape=(patch_height, patch_size, patch_size), \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","vars(config)\n","\n","# Create the default model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","print(\"Parameters transferred into the model.\")\n","print(config)\n","\n","# Shows a training batch and a validation batch.\n","plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","\n","pdf_export(trained = False, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
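As a rough sanity check before launching a long run, you can estimate the total training time from the number of steps and epochs. The per-step duration below is only an assumed placeholder; measure the real value from the first epoch's progress output, as it varies with GPU, patch size and batch size.

```python
# Hypothetical estimate only: seconds_per_step is an assumption, not a measurement.
seconds_per_step = 0.5
estimated_hours = number_of_epochs * number_of_steps * seconds_per_step / 3600
print(f"Estimated training time: {estimated_hours:.1f} h (keep this well below Colab's ~12 h limit)")
```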
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (N2V -- N2V Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","# the training starts.\n","history = model.train(X, X_val)\n","%memit\n","print(\"Model training is now done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='Noise2Void', \n"," description='Noise2Void 3D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='ZYX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBDeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained=True, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
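As a minimal, self-contained illustration of these metrics (not the notebook's exact normalisation), the snippet below compares a toy ground-truth slice with a perturbed copy using scikit-image:

```python
import numpy as np
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as psnr

gt = np.random.rand(128, 128).astype(np.float32)                                 # toy ground-truth slice
pred = np.clip(gt + 0.05 * np.random.randn(128, 128), 0, 1).astype(np.float32)   # toy prediction

mssim, ssim_map = structural_similarity(gt, pred, data_range=1.0, full=True,
                                         gaussian_weights=True, sigma=1.5,
                                         use_sample_covariance=False)
rse_map = np.abs(gt - pred)                 # root squared error map, i.e. sqrt((gt - pred)**2)
nrmse = np.sqrt(np.mean((gt - pred) ** 2))  # root mean squared error on normalised images
print(f"mSSIM: {mssim:.3f}  NRMSE: {nrmse:.3f}  PSNR: {psnr(gt, pred, data_range=1.0):.2f} dB")
```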
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","# Activate the pretrained model. 
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction,force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource,force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = 
psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. 
GT PSNR\"] = PSNR_GvS_list_mean\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane])\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#Activate the pretrained model. \n","#model_training = CARE(config=None, name=model_name, basedir=model_path)\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. 
Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model.\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Denoising images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX')\n"," \n","print(\"Prediction of images done.\")\n","\n","print(\"One example is displayed here.\")\n","\n","\n","#Display an example\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest')\n","plt.title('Noisy Input (single Z plane)');\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest')\n","plt.title('Prediction (single Z plane)');\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using Noise2Void 3D!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Noise2Void_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **Noise2Void (3D)**\n","\n","---\n","\n"," Noise2Void is a deep-learning method that can be used to denoise many types of images, including microscopy images and which was originally published by [Krull *et al.* on arXiv](https://arxiv.org/abs/1811.10980). It allows denoising of image data in a self-supervised manner, therefore high-quality, low noise equivalent images are not necessary to train this network. This is performed by \"masking\" a random subset of pixels in the noisy image and training the network to predict the values in these pixels. The resulting output is a denoised version of the image. Noise2Void is based on the popular U-Net network architecture, adapted from [CARE](https://www.nature.com/articles/s41592-018-0216-7).\n","\n"," **This particular notebook enables self-supervised denoised of 3D dataset. If you are interested in 2D dataset, you should use the Noise2Void 2D notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the following paper:\n","\n","**Noise2Void - Learning Denoising from Single Noisy Images**\n","from Krull *et al.* published on arXiv in 2018 (https://arxiv.org/abs/1811.10980)\n","\n","And source code found in: /~https://github.com/juglab/n2v\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For Noise2Void to train, it only requires a single noisy image but multiple images can be used. 
Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","Please note that you currently can **only use .tif files!**\n","\n","**We strongly recommend that you generate high signal to noise ration version of your noisy images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," - **Training dataset**\n"," - **Quality control dataset** (Optional but recomended)\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif \n"," - **Data to be predicted** \n"," - **Results**\n","\n","\n","The **Results** folder will contain the processed images, trained model and network parameters as csv file. Your original images remain unmodified.\n","\n","---\n","**Important note**\n","\n","- If you wish to **train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. Install Noise2Void and dependencies**\n","---"]},{"cell_type":"markdown","metadata":{"id":"kXu3EG6fpA00"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"yGkQrmBYpEFq","cellView":"form"},"source":["\n","#@markdown ##Install Noise2Void and dependencies\n","\n","\n","!pip uninstall -y tensorflow\n","!pip install tensorflow==2.4.1\n","\n","import tensorflow \n","\n","!pip uninstall -y keras-nightly\n","# !pip install q keras==2.2.5\n","!pip install q keras==2.3.1\n","\n","\n","!pip3 install h5py==2.10.0\n","!pip install n2v\n","!pip install wget\n","!pip install fpdf\n","!pip install memory_profiler\n","\n","#Force session restart\n","exit(0)\n","\n","print('--------')\n","print('TensorFlow version:')\n","print(tensorflow.__version__)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"FpR7Cz6fpGqa"},"source":["## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"kGskPwVSpJaO"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["#@markdown ##Load key dependencies\n","\n","Notebook_version = '1.13'\n","Network = 'Noise2Void (3D)'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","\n","import tensorflow\n","print('TensorFlow version:')\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","%load_ext memory_profiler\n","\n","# ------- Variable specific to N2V -------\n","from n2v.models import N2VConfig, N2V\n","from csbdeep.utils import plot_history\n","from n2v.utils.n2v_utils import manipulate_val_data\n","from n2v.internals.N2V_DataGenerator import N2V_DataGenerator\n","from csbdeep.io import save_tiff_imagej_compatible\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries 
installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 3D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[0]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(len(patches))+' image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(len(patches))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(26, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by default.'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>number_of_steps</td><td>{3}</td></tr>
<tr><td>percentage_validation</td><td>{4}</td></tr>
<tr><td>initial_learning_rate</td><td>{5}</td></tr>
\n"," \"\"\".format(number_of_epochs,str(patch_height)+'x'+str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," # pdf.set_font('')\n"," # pdf.set_font('Arial', size = 10, style = 'B')\n"," # pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," # pdf.set_font('')\n"," # pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Training Image', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_N2V3D.png').shape\n"," pdf.image('/content/TrainingDataExample_N2V3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Noise2Void 3D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," mSSIM_SvsGT = header[3]\n"," NRMSE_PvsGT = header[4]\n"," NRMSE_SvsGT = header[5]\n"," PSNR_PvsGT = header[6]\n"," PSNR_SvsGT = header[7]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," mSSIM_SvsGT = row[3]\n"," NRMSE_PvsGT = row[4]\n"," NRMSE_SvsGT = row[5]\n"," PSNR_PvsGT = row[6]\n"," PSNR_SvsGT = row[7]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}{7}
{0}{1}{2}{3}{4}{5}{6}{7}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Noise2Void: Krull, Alexander, Tim-Oliver Buchholz, and Florian Jug. \"Noise2void-learning denoising from single noisy images.\" Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0jHA0eb0lCf4"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
"]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`:** This is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 5.). **Default value: 30**\n","\n","**`patch_size`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 64**\n","\n","**`patch_height`:** The value should be smaller than the Z dimensions of the image and divisible by 4. When analysing isotropic stacks patch_size and patch_height should have similar values.\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 128**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0004**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","# Create DataGenerator-object.\n","datagen = N2V_DataGenerator()\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","imgs = datagen.load_imgs_from_directory(directory = Training_source, dims='ZYX')\n","\n","#@markdown ### Model name and path:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Training Parameters\n","#@markdown Number of steps and epochs:\n","\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown Patch size (pixels) and number\n","patch_size = 64#@param {type:\"number\"}\n","\n","patch_height = 4#@param {type:\"number\"}\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 128#@param {type:\"number\"}\n","number_of_steps = 100#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","initial_learning_rate = 0.0004 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," # number_of_steps is defined in the following cell in this case\n"," batch_size = 128\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0004\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","#Load one randomly chosen training target file\n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING + \"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not run)\n","Use_pretrained_model = False\n","\n","# Here we enable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = True\n","\n","print(\"Parameters initiated.\")\n","\n","\n","#Here we display a single z plane\n","\n","norm = simple_norm(x[mid_plane], percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_N2V3D.png',bbox_inches='tight',pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n"," By default data augmentation is enabled. 
Disable this option is you run out of RAM during the training.\n"," "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","#@markdown ##Play this cell to enable or disable data augmentation: \n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","\n","if not Use_Data_augmentation:\n"," print(\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a N2V 3D model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not 
exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","#Disable some of the warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","# Create batches from the training data.\n","patches = datagen.generate_patches_from_list(imgs, shape=(patch_height, patch_size, patch_size), augment=Use_Data_augmentation)\n","\n","# Patches are divited into training and validation patch set. This inhibits over-lapping of patches. \n","number_train_images =int(len(patches)*(percentage_validation/100))\n","X = patches[number_train_images:]\n","X_val = patches[:number_train_images]\n","\n","print(len(patches),\"patches created.\")\n","print(number_train_images,\"patch images for validation (\",percentage_validation,\"%).\")\n","print((len(patches)-number_train_images),\"patch images for training.\")\n","%memit \n","\n","#Here we automatically define number_of_step in function of training data and batch size\n","if (Use_Default_Advanced_Parameters): \n"," number_of_steps= int(X.shape[0]/batch_size) + 1\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# creates Congfig object. \n","config = N2VConfig(X, unet_kern_size=3, \n"," train_steps_per_epoch=number_of_steps,train_epochs=number_of_epochs, train_loss='mse', batch_norm=True, \n"," train_batch_size=batch_size, n2v_perc_pix=0.198, n2v_patch_shape=(patch_height, patch_size, patch_size), \n"," n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate = initial_learning_rate)\n","\n","vars(config)\n","\n","# Create the default model.\n","model = N2V(config=config, name=model_name, basedir=model_path)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","print(\"Parameters transferred into the model.\")\n","print(config)\n","\n","# Shows a training batch and a validation batch.\n","plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(X[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Training Patch');\n","plt.subplot(1,2,2)\n","plt.imshow(X_val[0,1,...,0],cmap='magma')\n","plt.axis('off')\n","plt.title('Validation Patch');\n","\n","\n","pdf_export(trained = False, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. 
Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the CSBDeep Fiji plugin (N2V -- N2V Predict). You can find it in your model folder (export.bioimage.io.zip and model.yaml). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["start = time.time()\n","\n","#@markdown ##Start training\n","%memit\n","# the training starts.\n","history = model.train(X, X_val)\n","%memit\n","print(\"Model training is now done.\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF(name='Noise2Void', \n"," description='Noise2Void 3D trained using ZeroCostDL4Mic.', \n"," authors=[\"You\"],\n"," test_img=X_val[0,...,0], axes='ZYX',\n"," patch_shape=(patch_size, patch_size))\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the CSBDeep Fiji plugin\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained=True, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
The higher the score, the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/Quality Control/'\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(path_metrics_save+'Prediction'):\n"," shutil.rmtree(path_metrics_save+'Prediction')\n","os.makedirs(path_metrics_save+'Prediction')\n","\n","#Here we allow the user to choose the number of tiles to be used when predicting the images\n","#@markdown #####To analyse large images, they need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expense of a longer runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increase these numbers until the OOM error disappears.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","# Activate the pretrained model. 
\n","model_training = N2V(config=None, name=QC_model_name, basedir=QC_model_path)\n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Perform prediction on all datasets in the Source_QC folder\n","for filename in os.listdir(Source_QC_folder):\n"," img = imread(os.path.join(Source_QC_folder, filename))\n"," n_slices = img.shape[0]\n"," predicted = model_training.predict(img, axes='ZYX', n_tiles=n_tilesZYX)\n"," os.chdir(path_metrics_save+'Prediction/')\n"," imsave('Predicted_'+filename, predicted)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," mSSIM_GvS_list = []\n"," NRMSE_GvP_list = []\n"," NRMSE_GvS_list = []\n"," PSNR_GvP_list = []\n"," PSNR_GvS_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," mSSIM_GvS_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," NRMSE_GvS_list_mean = []\n"," PSNR_GvP_list_mean = []\n"," PSNR_GvS_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," # Calculating the position of the mid-plane slice\n"," z_mid_plane = int(n_slices / 2)+1\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_SSIM_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsSource_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," # -------------------------------- Normalising the dataset --------------------------------\n","\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT_stack[z], test_source_stack[z], normalize_gt=True)\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = structural_similarity(test_GT_norm, test_source_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = img_as_float32(img_SSIM_GTvsPrediction,force_copy=False)\n"," img_SSIM_GTvsSource_stack[z] = img_as_float32(img_SSIM_GTvsSource,force_copy=False)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = img_as_float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_stack[z] = img_as_float32(img_RSE_GTvsSource)\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = 
psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource), str(PSNR_GTvsPrediction), str(PSNR_GTvsSource)])\n"," \n"," # Collect values to display in dataframe output\n"," slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n"," mSSIM_GvS_list.append(index_SSIM_GTvsSource)\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n"," NRMSE_GvS_list.append(NRMSE_GTvsSource)\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n"," PSNR_GvS_list.append(PSNR_GTvsSource)\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n"," SSIM_GTvsS_forDisplay = index_SSIM_GTvsSource\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n"," NRMSE_GTvsS_forDisplay = NRMSE_GTvsSource\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n"," mSSIM_GvS_list_mean.append(sum(mSSIM_GvS_list)/len(mSSIM_GvS_list))\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n"," NRMSE_GvS_list_mean.append(sum(NRMSE_GvS_list)/len(NRMSE_GvS_list))\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n"," PSNR_GvS_list_mean.append(sum(PSNR_GvS_list)/len(PSNR_GvS_list))\n","\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n","\n"," img_SSIM_GTvsSource_stack_32 = img_as_float32(img_SSIM_GTvsSource_stack, force_copy=False)\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsSource_stack_32 = img_as_float32(img_RSE_GTvsSource_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsSource_'+thisFile,img_SSIM_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsSource_'+thisFile,img_RSE_GTvsSource_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","pdResults[\"Input v. GT mSSIM\"] = mSSIM_GvS_list_mean\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","pdResults[\"Input v. GT NRMSE\"] = NRMSE_GvS_list_mean\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","pdResults[\"Input v. 
GT PSNR\"] = PSNR_GvS_list_mean\n","\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","\n","# Calculating the position of the mid-plane slice\n","z_mid_plane = int(img_GT.shape[0] / 2)+1\n","\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane])\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_SSIM_GTvsSource = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsSource_'+Test_FileList[-1]))\n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsS_forDisplay,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. 
Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","img_RSE_GTvsSource = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsSource_'+Test_FileList[-1]))\n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource[z_mid_plane], cmap = cmap, vmin=0, vmax = 1) \n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsS_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
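Stripped of the folder loop and display code, the prediction performed by the cell below reduces to a few calls. The sketch here is illustrative only: the file paths are placeholders, and `Prediction_model_name` / `Prediction_model_path` are assumed to be set as in that cell.

```python
# Illustrative sketch of the core prediction step (placeholder paths, not the full cell below).
from n2v.models import N2V
from tifffile import imread
from csbdeep.io import save_tiff_imagej_compatible

model = N2V(config=None, name=Prediction_model_name, basedir=Prediction_model_path)
stack = imread("/content/gdrive/MyDrive/example_stack.tif")       # placeholder input stack
restored = model.predict(stack, axes='ZYX', n_tiles=(1, 2, 2))    # tiles in Z, Y, X
save_tiff_imagej_compatible("/content/gdrive/MyDrive/Results/example_stack.tif", restored, axes='ZYX')
```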
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#Activate the pretrained model. \n","#model_training = CARE(config=None, name=model_name, basedir=model_path)\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved, then play the cell to predict output on your unseen images.\n","\n","#@markdown ###Path to data to analyse and where predicted output should be saved:\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else: \n"," print(bcolors.WARNING + '!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. 
Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 2#@param {type:\"number\"}\n","n_tiles_X = 2#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","#Activate the pretrained model.\n","config = None\n","model = N2V(config, Prediction_model_name, basedir=Prediction_model_path)\n","\n","print(\"Denoising images...\")\n","\n","thisdir = Path(Data_folder)\n","outputdir = Path(Result_folder)\n","suffix = '.tif'\n","\n","# r=root, d=directories, f = files\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," if \".tif\" in file:\n"," print(os.path.join(r, file))\n","\n","# The code by Lucas von Chamier\n","for r, d, f in os.walk(thisdir):\n"," for file in f:\n"," base_filename = os.path.basename(file)\n"," input_train = imread(os.path.join(r, file))\n"," pred_train = model.predict(input_train, axes='ZYX', n_tiles=n_tilesZYX)\n"," save_tiff_imagej_compatible(os.path.join(outputdir, base_filename), pred_train, axes='ZYX')\n"," \n","print(\"Prediction of images done.\")\n","\n","print(\"One example is displayed here.\")\n","\n","\n","#Display an example\n","random_choice=random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest')\n","plt.title('Noisy Input (single Z plane)');\n","plt.axis('off');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest')\n","plt.title('Prediction (single Z plane)');\n","plt.axis('off');"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"bPLQA1sIk8Fh"},"source":["# **7. 
Version log**\n","---\n","**v1.13**: \n","\n","* N2V now uses tensorflow 2.4.\n","* This version now includes an automatic restart allowing to set the h5py library to v2.10.\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","This version also now includes built-in version check and the version log that \n","\n","* This version also now includes built-in version check and the version log that you're reading now."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using Noise2Void 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb index 7650d8cd..4d0250f0 100644 --- a/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/StarDist_2D_ZeroCostDL4Mic.ipynb @@ -1,2269 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "StarDist_2D_ZeroCostDL4Mic.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.7" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "IkSguVy8Xv83" - }, - "source": [ - "# **StarDist (2D)**\n", - "---\n", - "\n", - "**StarDist 2D** is a deep-learning method that can be used to segment cell nuclei from bioimages and was first published by [Schmidt *et al.* in 2018, on arXiv](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 2D network is based on an adapted U-Net network architecture.\n", - "\n", - " **This particular notebook enables nuclei segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist 3D notebook instead.**\n", - "\n", - "---\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is largely based on the paper:\n", - "\n", - "**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. 
(https://arxiv.org/abs/1806.03535)\n", - "\n", - "and the 3D extension of the approach:\n", - "\n", - "**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n", - "\n", - "**The Original code** is freely available in GitHub:\n", - "/~https://github.com/mpicbg-csbd/stardist\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use our notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gKDLkLWUd-YX" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - " For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. 
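Before setting up training further down, it can be useful to verify that a paired dataset follows the structure described below (matching .tif file names in the source and target folders). The snippet is purely illustrative; the folder paths are placeholders based on the example structure shown in this section.

```python
# Illustrative check (not part of the original notebook): every .tif image should
# have an identically named mask. Folder paths below are placeholders.
import os

Training_source = "/content/gdrive/MyDrive/Experiment_A/Training - Images"
Training_target = "/content/gdrive/MyDrive/Experiment_A/Training - Masks"

source_files = sorted(f for f in os.listdir(Training_source) if f.endswith(".tif"))
target_files = sorted(f for f in os.listdir(Training_target) if f.endswith(".tif"))

print(f"{len(source_files)} source images, {len(target_files)} target masks")
missing_masks = set(source_files) - set(target_files)
if missing_masks:
    print("Images without a matching mask:", sorted(missing_masks))
```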
Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n", - "\n", - "**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. The quality control assessment can be done directly in this notebook.\n", - "\n", - "The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n", - "\n", - "Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n", - "\n", - "Please note that you currently can **only use .tif files!**\n", - "\n", - "You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n", - "\n", - "Here's a common data structure that can work:\n", - "* Experiment A\n", - " - **Training dataset**\n", - " - Images of nuclei (Training_source)\n", - " - img_1.tif, img_2.tif, ...\n", - " - Masks (Training_target)\n", - " - img_1.tif, img_2.tif, ...\n", - " - **Quality control dataset**\n", - " - Images of nuclei\n", - " - img_1.tif, img_2.tif\n", - " - Masks \n", - " - img_1.tif, img_2.tif\n", - " - **Data to be predicted**\n", - " - **Results**\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "\n", - "\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. 
Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zCvebubeSaGY", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "%tensorflow_version 1.x\n", - "\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sNIVx8_CLolt" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "01Djr8v-5pPk", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AdN8B91xZO0x" - }, - "source": [ - "# **2. Install StarDist and dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XuwTHSva_Y5K" - }, - "source": [ - "## **2.1. Install key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "fq21zJVFNASx", - "cellView": "form" - }, - "source": [ - "#@markdown ##Install StarDist and dependencies\n", - "\n", - "# Install packages which are not included in Google Colab\n", - "\n", - "!pip install tifffile # contains tools to operate tiff-files\n", - "!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). 
It uses Keras and Tensorflow.\n", - "!pip install stardist # contains tools to operate STARDIST.\n", - "!pip install gputools # improves STARDIST performances\n", - "#!pip install edt # improves STARDIST performances\n", - "!pip install wget\n", - "!pip install fpdf\n", - "!pip install PTable # Nice tables \n", - "!pip install zarr\n", - "!pip install imagecodecs\n", - "\n", - "#Force session restart\n", - "exit(0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "e_oQT-9180CX" - }, - "source": [ - "## **2.2. Restart your runtime**\n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cuPQR21r83vM" - }, - "source": [ - "** Your Runtime has automatically restarted. This is normal.**\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2uHigbVJ9CUh" - }, - "source": [ - "## **2.3. Load key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "HMAdm-Mc9HFz", - "cellView": "form" - }, - "source": [ - "#@markdown ##Load key dependencies\n", - "\n", - "Notebook_version = ['1.12']\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory \n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "\n", - "#%load_ext memory_profiler\n", - "\n", - "\n", - "%tensorflow_version 1.x\n", - "\n", - "import tensorflow\n", - "print(tensorflow.__version__)\n", - "print(\"Tensorflow enabled.\")\n", - "\n", - "\n", - "import imagecodecs\n", - "\n", - "# ------- Variable specific to Stardist -------\n", - "from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available, relabel_image_stardist, random_label_cmap, relabel_image_stardist, _draw_polygons, export_imagej_rois\n", - "from stardist.models import Config2D, StarDist2D, StarDistData2D # import objects\n", - "from stardist.matching import matching_dataset\n", - "from __future__ import print_function, unicode_literals, absolute_import, division\n", - "from csbdeep.utils import Path, normalize, 
download_and_extract_zip_file, plot_history # for loss plot\n", - "from csbdeep.io import save_tiff_imagej_compatible\n", - "import numpy as np\n", - "np.random.seed(42)\n", - "lbl_cmap = random_label_cmap()\n", - "%matplotlib inline\n", - "%config InlineBackend.figure_format = 'retina'\n", - "\n", - "from PIL import Image\n", - "import zarr\n", - "from zipfile import ZIP_DEFLATED\n", - "from csbdeep.data import Normalizer, normalize_mi_ma\n", - "import imagecodecs\n", - "\n", - "\n", - "class MyNormalizer(Normalizer):\n", - " def __init__(self, mi, ma):\n", - " self.mi, self.ma = mi, ma\n", - " def before(self, x, axes):\n", - " return normalize_mi_ma(x, self.mi, self.ma, dtype=np.float32)\n", - " def after(*args, **kwargs):\n", - " assert False\n", - " @property\n", - " def do_after(self):\n", - " return False\n", - "\n", - "\n", - "\n", - "# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - "import urllib\n", - "import os, random\n", - "import shutil \n", - "import zipfile\n", - "from tifffile import imread, imsave\n", - "import time\n", - "import sys\n", - "import wget\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import csv\n", - "from glob import glob\n", - "from scipy import signal\n", - "from scipy import ndimage\n", - "from skimage import io\n", - "from sklearn.linear_model import LinearRegression\n", - "from skimage.util import img_as_uint\n", - "import matplotlib as mpl\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from astropy.visualization import simple_norm\n", - "from skimage import img_as_float32, img_as_ubyte, img_as_float\n", - "from skimage.util import img_as_ubyte\n", - "from tqdm import tqdm \n", - "import cv2\n", - "from fpdf import FPDF, HTMLMixin\n", - "from datetime import datetime\n", - "from pip._internal.operations.freeze import freeze\n", - "import subprocess\n", - "\n", - "# For sliders and dropdown menu and progress bar\n", - "from ipywidgets import interact\n", - "import ipywidgets as widgets\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "\n", - "#Disable some of the tensorflow warnings\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "print('------------------------------------------')\n", - "print(\"Libraries installed\")\n", - "\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "print('Notebook version: '+Notebook_version[0])\n", - "strlist = Notebook_version[0].split('.')\n", - "Notebook_version_main = strlist[0]+'.'+strlist[1]\n", - "if Notebook_version_main == Latest_notebook_version.columns:\n", - " print(\"This notebook is up-to-date.\")\n", - "else:\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "\n", - "# PDF export\n", - "\n", - "def pdf_export(trained=False, augmentation = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'StarDist 2D'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and method:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras','csbdeep']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The GPU used was a '+gpu_name+'.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)\n", - " \n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ParameterValue
number_of_epochs{0}
patch_size{1}
batch_size{2}
number_of_steps{3}
percentage_validation{4}
n_rays{5}
grid_parameter{6}
initial_learning_rate{7}
\n", - " \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,n_rays,grid_parameter,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_StarDist2D.png').shape\n", - " pdf.image('/content/TrainingDataExample_StarDist2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " if augmentation:\n", - " ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " pdf.multi_cell(190, 5, txt = ref_4, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", - "\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'Stardist 2D'\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n", - " if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n", - " pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n", - " pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+'/Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " #image = header[0]\n", - " #PvGT_IoU = header[1]\n", - " fp = header[2]\n", - " tp = header[3]\n", - " fn = header[4]\n", - " precision = header[5]\n", - " recall = header[6]\n", - " acc = header[7]\n", - " f1 = header[8]\n", - " n_true = header[9]\n", - " n_pred = header[10]\n", - " mean_true = header[11]\n", - " mean_matched = header[12]\n", - " panoptic = header[13]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(\"image #\",\"Prediction v. 
GT IoU\",'false pos.','true pos.','false neg.',precision,recall,acc,f1,n_true,n_pred,mean_true,mean_matched,panoptic)\n", - " html = html+header\n", - " i=0\n", - " for row in metrics:\n", - " i+=1\n", - " #image = row[0]\n", - " PvGT_IoU = row[1]\n", - " fp = row[2]\n", - " tp = row[3]\n", - " fn = row[4]\n", - " precision = row[5]\n", - " recall = row[6]\n", - " acc = row[7]\n", - " f1 = row[8]\n", - " n_true = row[9]\n", - " n_pred = row[10]\n", - " mean_true = row[11]\n", - " mean_matched = row[12]\n", - " panoptic = row[13]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(str(i),str(round(float(PvGT_IoU),3)),fp,tp,fn,str(round(float(precision),3)),str(round(float(recall),3)),str(round(float(acc),3)),str(round(float(f1),3)),n_true,n_pred,str(round(float(mean_true),3)),str(round(float(mean_matched),3)),str(round(float(panoptic),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
</body></table>
\"\"\"\n", - " \n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HLYcZR9gMv42" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FQ_QxtSWQ7CL" - }, - "source": [ - "## **3.1. Setting main training parameters**\n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AuESFimvMv43" - }, - "source": [ - " **Paths for training, predictions and results**\n", - "\n", - "\n", - "**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "\n", - "**Training parameters**\n", - "\n", - "**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 50-100 epochs, but a full training should run for up to 400 epochs. Evaluate the performance after training (see 5.). **Default value: 100**\n", - "\n", - "**Advanced Parameters - experienced users only**\n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 2**\n", - "\n", - "**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. 
**Default value: Number of patch / batch_size**\n", - "\n", - "**`patch_size`:** Input the size of the patches use to train StarDist 2D (length of a side). The value should be smaller or equal to the dimensions of the image. Make the patch size as large as possible and divisible by 8. **Default value: dimension of the training images** \n", - "\n", - "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n", - "\n", - "**`n_rays`:** Set number of rays (corners) used for StarDist (for instance, a square has 4 corners). **Default value: 32** \n", - "\n", - "**`grid_parameter`:** increase this number if the cells/nuclei are very large or decrease it if they are very small. **Default value: 2**\n", - "\n", - "**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n", - "\n", - "**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size value until the OOM error disappear.**\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ewpNJ_I0Mv47", - "cellView": "form" - }, - "source": [ - "#@markdown ###Path to training images: \n", - "Training_source = \"\" #@param {type:\"string\"}\n", - "\n", - "Training_target = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "#@markdown ###Name of the model and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "#trained_model = model_path \n", - "\n", - "\n", - "#@markdown ### Other parameters for training:\n", - "number_of_epochs = 100#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please input:\n", - "\n", - "#GPU_limit = 90 #@param {type:\"number\"}\n", - "batch_size = 4 #@param {type:\"number\"}\n", - "number_of_steps = 0#@param {type:\"number\"}\n", - "patch_size = 1024 #@param {type:\"number\"}\n", - "percentage_validation = 10 #@param {type:\"number\"}\n", - "n_rays = 32 #@param {type:\"number\"}\n", - "grid_parameter = 2#@param [1, 2, 4, 8, 16, 32] {type:\"raw\"}\n", - "initial_learning_rate = 0.0003 #@param {type:\"number\"}\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 2\n", - " n_rays = 32\n", - " percentage_validation = 10\n", - " grid_parameter = 2\n", - " initial_learning_rate = 0.0003\n", - "\n", - "percentage = percentage_validation/100\n", - "\n", - "#here we check that no model with the same name already exist, if so print a warning\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted !!\")\n", - " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n", - "\n", - " \n", - "# Here we open will randomly chosen input and output image\n", - "random_choice = random.choice(os.listdir(Training_source))\n", - "x = imread(Training_source+\"/\"+random_choice)\n", - "\n", - "# Here we check the image dimensions\n", - "\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "print('Loaded images (width, length) =', x.shape)\n", - "\n", - "# If default parameters, patch size is the same as image size\n", - "if (Use_Default_Advanced_Parameters):\n", - " patch_size = min(Image_Y, Image_X)\n", - " \n", - "#Hyperparameters failsafes\n", - "\n", - "# Here we check that patch_size is smaller than the smallest xy dimension of the image \n", - "\n", - "if patch_size > min(Image_Y, Image_X):\n", - " patch_size = min(Image_Y, Image_X)\n", - " print(bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "if patch_size > 2048:\n", - " patch_size = 2048\n", - " print(bcolors.WARNING + \" Your image dimension is large; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "\n", - "# Here we check that the patch_size is divisible by 16\n", - "if not patch_size % 16 == 0:\n", - " patch_size = ((int(patch_size / 16)-1) * 16)\n", - " print(bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is:\",patch_size)\n", - "\n", - "# Here we disable pre-trained model by default (in case the next cell is not ran)\n", - "Use_pretrained_model = False\n", - "\n", - "# Here we disable data augmentation by default (in case the cell is not ran)\n", - "\n", - "Use_Data_augmentation = False\n", - "\n", - "\n", - "print(\"Parameters initiated.\")\n", - "\n", - "\n", - "os.chdir(Training_target)\n", - "y = imread(Training_target+\"/\"+random_choice)\n", - "\n", - "#Here we use a simple normalisation strategy to visualise the image\n", - "norm = simple_norm(x, percent = 99)\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n", - "plt.title('Training source')\n", - "plt.axis('off');\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y, interpolation='nearest', cmap=lbl_cmap)\n", - "plt.title('Training target')\n", - "plt.axis('off');\n", - "plt.savefig('/content/TrainingDataExample_StarDist2D.png',bbox_inches='tight',pad_inches=0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xyQZKby8yFME" - }, - "source": [ - "## **3.2. Data augmentation**\n", - "---\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_jCy7xOx2g3" - }, - "source": [ - "Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n", - "\n", - "Data augmentation is performed here via random rotations, flips, and intensity changes.\n", - "\n", - "\n", - " **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "DMqWq5-AxnFU", - "cellView": "form" - }, - "source": [ - "#Data augmentation\n", - "\n", - "Use_Data_augmentation = False #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ####Choose a factor by which you want to multiply your original dataset\n", - "\n", - "Multiply_dataset_by = 4 #@param {type:\"slider\", min:1, max:10, step:1}\n", - "\n", - "\n", - "def random_fliprot(img, mask): \n", - " assert img.ndim >= mask.ndim\n", - " axes = tuple(range(mask.ndim))\n", - " perm = tuple(np.random.permutation(axes))\n", - " img = img.transpose(perm + tuple(range(mask.ndim, img.ndim))) \n", - " mask = mask.transpose(perm) \n", - " for ax in axes: \n", - " if np.random.rand() > 0.5:\n", - " img = np.flip(img, axis=ax)\n", - " mask = np.flip(mask, axis=ax)\n", - " return img, mask \n", - "\n", - "def random_intensity_change(img):\n", - " img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)\n", - " return img\n", - "\n", - "\n", - "def augmenter(x, y):\n", - " \"\"\"Augmentation of a single input/label image pair.\n", - " x is an input image\n", - " y is the corresponding ground-truth label image\n", - " \"\"\"\n", - " x, y = random_fliprot(x, y)\n", - " x = random_intensity_change(x)\n", - " # add some gaussian noise\n", - " sig = 0.02*np.random.uniform(0,1)\n", - " x = x + sig*np.random.normal(0,1,x.shape)\n", - " return x, y\n", - "\n", - "\n", - "\n", - "if Use_Data_augmentation:\n", - " augmenter = augmenter\n", - " print(\"Data augmentation enabled\")\n", - "\n", - "\n", - "if not Use_Data_augmentation:\n", - " augmenter = None\n", - " print(bcolors.WARNING+\"Data augmentation disabled\") \n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3L9zSGtORKYI" - }, - "source": [ - "\n", - "## **3.3. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
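As a quick illustration of the idea behind this section, the sketch below shows the bare minimum needed to start a StarDist 2D model from previously saved weights. It is only a simplified outline of what the next cell does automatically: the folder and file names are placeholders, and the full notebook additionally recovers the learning rate from `training_evaluation.csv` when it is available.

```python
# Minimal sketch (placeholder paths): start training from existing StarDist weights.
import os
from stardist.models import Config2D, StarDist2D

pretrained_model_path = "/content/my_previous_model"                     # hypothetical folder
h5_file_path = os.path.join(pretrained_model_path, "weights_best.h5")    # or weights_last.h5

conf = Config2D(n_rays=32, train_patch_size=(512, 512), n_channel_in=1)
model = StarDist2D(conf, name="my_new_model", basedir="/content/models")

if os.path.exists(h5_file_path):
    model.load_weights(h5_file_path)   # these weights become the starting point for training
else:
    print("No pretrained weights found; training will start from scratch.")
```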
" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "9vC2n-HeLdiJ", - "cellView": "form" - }, - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "\n", - "pretrained_model_choice = \"2D_versatile_fluo_from_Stardist_Fiji\" #@param [\"Model_from_file\", \"2D_versatile_fluo_from_Stardist_Fiji\", \"2D_Demo_Model_from_Stardist_Github\", \"Versatile_H&E_nuclei\"]\n", - "\n", - "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n", - "\n", - "\n", - "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - "# --------------------- Load the model from the choosen path ------------------------\n", - " if pretrained_model_choice == \"Model_from_file\":\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "\n", - "# --------------------- Download the Demo 2D model provided in the Stardist 2D github ------------------------\n", - "\n", - " if pretrained_model_choice == \"2D_Demo_Model_from_Stardist_Github\":\n", - " pretrained_model_name = \"2D_Demo\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " print(\"Downloading the 2D_Demo_Model_from_Stardist_Github\")\n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/config.json\", pretrained_model_path)\n", - " wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/thresholds.json\", pretrained_model_path)\n", - " wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_best.h5?raw=true\", pretrained_model_path) \n", - " wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "# --------------------- Download the Demo 2D_versatile_fluo_from_Stardist_Fiji ------------------------\n", - "\n", - " if pretrained_model_choice == \"2D_versatile_fluo_from_Stardist_Fiji\":\n", - " print(\"Downloading the 2D_versatile_fluo_from_Stardist_Fiji\")\n", - " pretrained_model_name = \"2D_versatile_fluo\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " \n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " \n", - " wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_fluo.zip\", pretrained_model_path)\n", - " \n", - " with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_fluo.zip\", 'r') as zip_ref:\n", - " zip_ref.extractall(pretrained_model_path)\n", - " \n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n", - "\n", - "# --------------------- Download the Versatile (H&E nuclei)_fluo_from_Stardist_Fiji ------------------------\n", - "\n", - " if pretrained_model_choice == \"Versatile_H&E_nuclei\":\n", - " print(\"Downloading the Versatile_H&E_nuclei 
from_Stardist_Fiji\")\n", - " pretrained_model_name = \"2D_versatile_he\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " \n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " \n", - " wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_he.zip\", pretrained_model_path)\n", - " \n", - " with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_he.zip\", 'r') as zip_ref:\n", - " zip_ref.extractall(pretrained_model_path)\n", - " \n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n", - "\n", - "\n", - "# --------------------- Add additional pre-trained models here ------------------------\n", - "\n", - "\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n", - " if not os.path.exists(h5_file_path):\n", - " print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist' + W)\n", - " Use_pretrained_model = False\n", - "\n", - " \n", - "# If the model path contains a pretrain model, we load the training rate, \n", - " if os.path.exists(h5_file_path):\n", - "#Here we check if the learning rate can be loaded from the quality control folder\n", - " if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - "\n", - " with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " \n", - " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"pretrained network learning rate found\")\n", - " #find the last learning rate\n", - " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n", - " #Find the learning rate corresponding to the lowest validation loss\n", - " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n", - " #print(min_val_loss)\n", - " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n", - "\n", - " if Weights_choice == \"last\":\n", - " print('Last learning rate: '+str(lastLearningRate))\n", - "\n", - " if Weights_choice == \"best\":\n", - " print('Learning rate of best validation loss: '+str(bestLearningRate))\n", - "\n", - " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n", - "\n", - "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n", - " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - "\n", - "\n", - "# Display info about the pretrained model to be loaded (or not)\n", - "if Use_pretrained_model:\n", - " print('Weights found in:')\n", - " print(h5_file_path)\n", - " print('will be loaded prior to training.')\n", - "\n", - "else:\n", - " print(bcolors.WARNING+'No pretrained network will be used.')\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MCGklf1vZf2M" - }, - "source": [ - "#**4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1KYOuygETJkT" - }, - "source": [ - "## **4.1. Prepare the training data and model for training**\n", - "---\n", - "Here, we use the information from 3. to build the model and convert the training data into a suitable format for training." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lIUAOJ_LMv5E", - "cellView": "form" - }, - "source": [ - "#@markdown ##Create the model and dataset objects\n", - "\n", - "# --------------------- Here we delete the model folder if it already exist ------------------------\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "\n", - "\n", - "# --------------------- Here we load the augmented data or the raw data ------------------------\n", - "\n", - "\n", - "Training_source_dir = Training_source\n", - "Training_target_dir = Training_target\n", - "# --------------------- ------------------------------------------------\n", - "\n", - "training_images_tiff=Training_source_dir+\"/*.tif\"\n", - "mask_images_tiff=Training_target_dir+\"/*.tif\"\n", - "\n", - "# this funtion imports training images and masks and sorts them suitable for the network\n", - "X = sorted(glob(training_images_tiff)) \n", - "Y = sorted(glob(mask_images_tiff)) \n", - "\n", - "# assert -funtion check that X and Y really have images. If not this cell raises an error\n", - "assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n", - "\n", - "# Here we map the training dataset (images and masks).\n", - "X = list(map(imread,X))\n", - "Y = list(map(imread,Y))\n", - "n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n", - "\n", - "# --------------------- Piece of code to remove alpha channel in RGBA images ------------------------\n", - "\n", - "if n_channel == 4:\n", - " X[:] = [i[:,:,:3] for i in X]\n", - " print(\"The alpha channel has been removed\")\n", - " n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n", - "\n", - "\n", - "if not Use_Data_augmentation:\n", - " augmenter = None\n", - " \n", - "\n", - "#Normalize images and fill small label holes.\n", - "\n", - "if n_channel == 1:\n", - " axis_norm = (0,1) # normalize channels independently\n", - " print(\"Normalizing image channels independently\")\n", - "\n", - "\n", - "if n_channel > 1:\n", - " axis_norm = (0,1,2) # normalize channels jointly\n", - " print(\"Normalizing image channels jointly\") \n", - " sys.stdout.flush()\n", - "\n", - "X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n", - "Y = [fill_label_holes(y) for y in tqdm(Y)]\n", - "\n", - "#Here we split the your training dataset into training images (90 %) and validation images (10 %). 
\n", - "#It is advisable to use 10 % of your training dataset for validation. This ensures the truthfull validation error value. If only few validation images are used network may choose too easy or too challenging images for validation. \n", - "# split training data (images and masks) into training images and validation images.\n", - "assert len(X) > 1, \"not enough training data\"\n", - "rng = np.random.RandomState(42)\n", - "ind = rng.permutation(len(X))\n", - "n_val = max(1, int(round(percentage * len(ind))))\n", - "ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n", - "X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n", - "X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n", - "print('number of images: %3d' % len(X))\n", - "print('- training: %3d' % len(X_trn))\n", - "print('- validation: %3d' % len(X_val))\n", - "\n", - "# Use OpenCL-based computations for data generator during training (requires 'gputools')\n", - "# Currently always false for stability\n", - "use_gpu = False and gputools_available()\n", - "\n", - "#Here we ensure that our network has a minimal number of steps\n", - "\n", - "if (Use_Default_Advanced_Parameters) or (number_of_steps == 0): \n", - " # number_of_steps= (int(len(X)/batch_size)+1)\n", - " number_of_steps = int(Image_X*Image_Y/(patch_size*patch_size))*(int(len(X)/batch_size)+1)\n", - " if (Use_Data_augmentation):\n", - " augmentation_factor = Multiply_dataset_by\n", - " number_of_steps = number_of_steps * augmentation_factor\n", - "\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "#Here we ensure that the learning rate set correctly when using pre-trained models\n", - "if Use_pretrained_model:\n", - " if Weights_choice == \"last\":\n", - " initial_learning_rate = lastLearningRate\n", - "\n", - " if Weights_choice == \"best\": \n", - " initial_learning_rate = bestLearningRate\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "\n", - "\n", - "conf = Config2D (\n", - " n_rays = n_rays,\n", - " use_gpu = use_gpu,\n", - " train_batch_size = batch_size,\n", - " n_channel_in = n_channel,\n", - " train_patch_size = (patch_size, patch_size),\n", - " grid = (grid_parameter, grid_parameter),\n", - " train_learning_rate = initial_learning_rate,\n", - ")\n", - "\n", - "# Here we create a model according to section 5.3.\n", - "model = StarDist2D(conf, name=model_name, basedir=model_path)\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "# Load the pretrained weights \n", - "if Use_pretrained_model:\n", - " model.load_weights(h5_file_path)\n", - "\n", - "\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "#Here we check the FOV of the network.\n", - "median_size = calculate_extents(list(Y), np.median)\n", - "fov = np.array(model._axes_tile_overlap('YX'))\n", - "if any(median_size > fov):\n", - " print(bcolors.WARNING+\"WARNING: median object size larger than field of view of the neural network.\")\n", - "print(conf)\n", - "\n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n", - "\n", - "print(\"Number of steps: \"+str(number_of_steps))\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0Dfn8ZsEMv5d" - }, - "source": [ - "## **4.2. Start Training**\n", - "---\n", - "\n", - "When playing the cell below you should see updates after each epoch (round). 
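For readers wondering how many of these per-epoch updates to expect, here is a small worked example of the default `number_of_steps` rule applied in the cell of section 4.1 above (all numbers are made up for illustration):

```python
# Illustrative numbers only; your own dataset and settings will differ.
Image_X = Image_Y = 1024   # size of the training images (pixels)
patch_size = 512           # patch side length chosen in section 3.1
n_images = 40              # number of image/mask pairs in Training_source
batch_size = 2

patches_per_image = int(Image_X * Image_Y / (patch_size * patch_size))   # 4
number_of_steps = patches_per_image * (int(n_images / batch_size) + 1)   # 4 * 21 = 84

# With data augmentation enabled, this is further multiplied by Multiply_dataset_by.
print(number_of_steps, "gradient updates per epoch")
```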
Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n", - "\n", - "**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the Stardist Fiji plugin. You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder." - ] - }, - { - "cell_type": "code", - "metadata": { - "scrolled": true, - "id": "iwNmp1PUzRDQ", - "cellView": "form" - }, - "source": [ - "#@markdown ##Start training\n", - "\n", - "start = time.time()\n", - "\n", - "\n", - "history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n", - " epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n", - "None;\n", - "\n", - "print(\"Training done\")\n", - "\n", - "print(\"Network optimization in progress\")\n", - "#Here we optimize the network.\n", - "model.optimize_thresholds(X_val, Y_val)\n", - "\n", - "print(\"Done\")\n", - "\n", - "# convert the history.history dict to a pandas DataFrame: \n", - "lossData = pd.DataFrame(history.history) \n", - "\n", - "if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n", - " shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "# The training evaluation.csv is saved (overwrites the Files if needed). \n", - "lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n", - "with open(lossDataCSVpath, 'w') as f:\n", - " writer = csv.writer(f)\n", - " writer.writerow(['loss','val_loss', 'learning rate'])\n", - " for i in range(len(history.history['loss'])):\n", - " writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n", - "\n", - "\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "model.export_TF()\n", - "\n", - "print(\"Your model has been sucessfully exported and can now also be used in the Stardist Fiji plugin\")\n", - "\n", - "pdf_export(trained=True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n", - "\n", - "#Create a pdf document with training summary" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_0Hynw3-xHp1" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows you to perform important quality checks on the validity and generalisability of the trained model. 
\n", - "\n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "eAJzMwPA6tlH", - "cellView": "form" - }, - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "QC_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "QC_model_name = os.path.basename(QC_model_folder)\n", - "QC_model_path = os.path.dirname(QC_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", - "\n", - "full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n", - "if os.path.exists(full_QC_model_path):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", - "else: \n", - " print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dhJROwlAMv5o" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vMzSP50kMv5p", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n", - "\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "\n", - "with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[0]))\n", - " vallossDataFromCSV.append(float(row[1]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (log scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n", - "plt.show()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X5_92nL2xdP6" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n", - "\n", - "The **Intersection over Union** (IuO) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n", - "\n", - "Here, the IuO is both calculated over the whole image and on a per-object basis. The value displayed below is the IuO value calculated over the entire image. The IuO value calculated on a per-object basis is used to calculate the other metrics displayed.\n", - "\n", - "“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. \n", - "\n", - "When a segmented object has an IuO value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. The number of “**true positives**” is available in the table below. The number of “false positive” is then defined as “**false positive**” = “n_pred” - “true positive”. The number of “false negative” is defined as “false negative” = “n_true” - “true positive”.\n", - "\n", - "The mean_matched_score is the mean IoUs of matched true positives. The mean_true_score is the mean IoUs of matched true positives but normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 2019](https://arxiv.org/abs/1801.00868).\n", - "\n", - "For more information about the other metric displayed, please consult the SI of the paper describing ZeroCostDL4Mic.\n", - "\n", - " The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"." 
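To make the definitions above concrete, here is a small self-contained example (toy arrays and made-up object counts, not real data) showing how the image-level IoU and the object-level precision, recall and f1 score relate to the true/false positive counts reported in the table:

```python
import numpy as np

# Image-level Intersection over Union on two tiny binary masks
ground_truth = np.array([[0, 1, 1],
                         [0, 1, 1],
                         [0, 0, 0]], dtype=bool)
prediction   = np.array([[0, 0, 1],
                         [0, 1, 1],
                         [0, 1, 0]], dtype=bool)
intersection = np.logical_and(ground_truth, prediction).sum()   # 3 pixels
union        = np.logical_or(ground_truth, prediction).sum()    # 5 pixels
iou_score    = intersection / union                             # 0.6

# Object-level metrics, following the definitions given above
# (e.g. 10 ground-truth objects, 9 predicted, 8 matched with IoU > 0.5)
tp, n_true, n_pred = 8, 10, 9
fp = n_pred - tp                    # 1 false positive
fn = n_true - tp                    # 2 false negatives
precision = tp / (tp + fp)          # ~0.889
recall    = tp / (tp + fn)          # 0.8
f1 = 2 * precision * recall / (precision + recall)   # ~0.842
```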
- ] - }, - { - "cell_type": "code", - "metadata": { - "id": "w90MdriMxhjD", - "cellView": "form" - }, - "source": [ - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "\n", - "from stardist.matching import matching\n", - "from stardist.plot import render_label, render_label_pred \n", - "\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_QC_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "\n", - "#Create a quality control Folder and check if the folder already exist\n", - "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n", - " os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n", - "\n", - "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n", - " shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "\n", - "# Generate predictions from the Source_QC_folder and save them in the QC folder\n", - "\n", - "Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n", - "\n", - "np.random.seed(16)\n", - "lbl_cmap = random_label_cmap()\n", - "Z = sorted(glob(Source_QC_folder_tif))\n", - "Z = list(map(imread,Z))\n", - "n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n", - "\n", - "print('Number of test dataset found in the folder: '+str(len(Z)))\n", - " \n", - "#Normalize images.\n", - "\n", - "if n_channel == 1:\n", - " axis_norm = (0,1) # normalize channels independently\n", - " print(\"Normalizing image channels independently\")\n", - "\n", - "if n_channel > 1:\n", - " axis_norm = (0,1,2) # normalize channels jointly\n", - " print(\"Normalizing image channels jointly\") \n", - "\n", - "model = StarDist2D(None, name=QC_model_name, basedir=QC_model_path)\n", - "\n", - "names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n", - "\n", - " \n", - "# modify the names to suitable form: path_images/image_numberX.tif\n", - " \n", - "lenght_of_Z = len(Z)\n", - " \n", - "for i in range(lenght_of_Z):\n", - " img = normalize(Z[i], 1,99.8, axis=axis_norm)\n", - " labels, polygons = model.predict_instances(img)\n", - " os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - " imsave(names[i], labels, polygons)\n", - "\n", - "# Here we start testing the differences between GT and predicted masks\n", - "\n", - "with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file, delimiter=\",\")\n", - " writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\", \"false positive\", \"true positive\", \"false negative\", \"precision\", \"recall\", \"accuracy\", \"f1 score\", \"n_true\", \"n_pred\", \"mean_true_score\", \"mean_matched_score\", \"panoptic_quality\"]) \n", - "\n", - "# define the images\n", - "\n", - " for n in os.listdir(Source_QC_folder):\n", - " \n", - " if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n", - " print('Running QC on: '+n)\n", - " test_input = io.imread(os.path.join(Source_QC_folder,n))\n", - " test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n", - " test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n", - "\n", - " # Calculate the matching (with IoU threshold `thresh`) and all metrics\n", - " \n", - " stats = matching(test_ground_truth_image, test_prediction, thresh=0.5)\n", - "\n", - " #Convert pixel values to 0 or 255\n", - " test_prediction_0_to_255 = test_prediction\n", - " test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n", - "\n", - " #Convert pixel values to 0 or 255\n", - " test_ground_truth_0_to_255 = test_ground_truth_image\n", - " test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n", - "\n", - "\n", - " # Intersection over Union metric\n", - "\n", - " intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n", - " union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - " writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])\n", - "\n", - "from tabulate import tabulate\n", - "\n", - "df = pd.read_csv (QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\")\n", - "print(tabulate(df, headers='keys', tablefmt='psql'))\n", - "\n", - "\n", - "from astropy.visualization import simple_norm\n", - "\n", - "# ------------- For display ------------\n", - "print('--------------------------------------------------------------')\n", - "@interact\n", - "def show_QC_results(file = os.listdir(Source_QC_folder)):\n", - " \n", - "\n", - " plt.figure(figsize=(25,5))\n", - " if n_channel > 1:\n", - " source_image = io.imread(os.path.join(Source_QC_folder, file))\n", - " if n_channel == 1:\n", - " source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n", - "\n", - " target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)\n", - " prediction = io.imread(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\"+file, as_gray = True)\n", - "\n", - " stats = matching(prediction, target_image, thresh=0.5)\n", - "\n", - " target_image_mask = np.empty_like(target_image)\n", - " target_image_mask[target_image > 0] = 255\n", - " target_image_mask[target_image == 0] = 0\n", - " \n", - " prediction_mask = np.empty_like(prediction)\n", - " prediction_mask[prediction > 0] = 255\n", - " prediction_mask[prediction == 0] = 0\n", - "\n", - " intersection = np.logical_and(target_image_mask, prediction_mask)\n", - " union = np.logical_or(target_image_mask, prediction_mask)\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - "\n", - " norm = simple_norm(source_image, percent = 99)\n", - "\n", - " #Input\n", - " plt.subplot(1,4,1)\n", - " plt.axis('off')\n", - " if 
n_channel > 1:\n", - " plt.imshow(source_image)\n", - " if n_channel == 1:\n", - " plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n", - " plt.title('Input')\n", - "\n", - " #Ground-truth\n", - " plt.subplot(1,4,2)\n", - " plt.axis('off')\n", - " plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n", - " plt.title('Ground Truth')\n", - "\n", - " #Prediction\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n", - " plt.title('Prediction')\n", - "\n", - " #Overlay\n", - " plt.subplot(1,4,4)\n", - " plt.axis('off')\n", - " plt.imshow(target_image_mask, cmap='Greens')\n", - " plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n", - " plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));\n", - " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "\n", - "qc_pdf_export()" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-tJeeJjLnRkP" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d8wuQGjoq6eN" - }, - "source": [ - "\n", - "\n", - "## **6.1 Generate prediction(s) from unseen dataset**\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.\n", - "\n", - "---\n", - "\n", - "The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n", - "\n", - "**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output ROI.\n", - "\n", - "**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n", - "\n", - "\n", - "In stardist the following results can be exported:\n", - "- Region of interest (ROI) that can be opened in ImageJ / Fiji. The ROI are saved inside of a .zip file in your choosen result folder. To open the ROI in Fiji, just drag and drop the zip file !**\n", - "- The predicted mask images\n", - "- A tracking file that can easily be imported into Trackmate to track the nuclei.\n", - "- A CSV file that contains the number of nuclei detected per image. \n", - "- A CSV file that contains the coordinate the centre of each detected nuclei (single image only). 
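The CSV outputs listed above are plain tables, so they can be reused directly for downstream analysis. The sketch below is only an example of reading them back with pandas: the paths are placeholders, and the column names used for the count file are chosen here for readability (that file is written without a header row).

```python
import pandas as pd

results_folder = "/content/drive/MyDrive/Results"   # placeholder path

# Per-image object counts (Nuclei_count.csv is written without a header row)
counts = pd.read_csv(results_folder + "/Nuclei_count.csv",
                     header=None, names=["image", "object_count"])
print(counts)

# Centre coordinates of each detected object for one image (columns 'Y' and 'X')
centres = pd.read_csv(results_folder + "/example_image_Nuclei_centre.csv")
print(centres[["X", "Y"]].head())
```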
\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "y2TD5p7MZrEb", - "cellView": "form" - }, - "source": [ - "Single_Images = 1\n", - "Stacks = 2\n", - "\n", - "#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n", - "\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Results_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ###Are your data single images or stacks?\n", - "\n", - "Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n", - "\n", - "#@markdown ###What outputs would you like to generate?\n", - "Region_of_interests = True #@param {type:\"boolean\"}\n", - "Mask_images = True #@param {type:\"boolean\"}\n", - "Tracking_file = False #@param {type:\"boolean\"}\n", - "\n", - "\n", - "# model name and path\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " print(bcolors.WARNING+'!! 
WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "#single images\n", - "\n", - "if Data_type == 1 :\n", - "\n", - " Data_folder = Data_folder+\"/*.tif\"\n", - "\n", - " print(\"Single images are now beeing predicted\")\n", - " np.random.seed(16)\n", - " lbl_cmap = random_label_cmap()\n", - " X = sorted(glob(Data_folder))\n", - " X = list(map(imread,X))\n", - " n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n", - " \n", - " # axis_norm = (0,1,2) # normalize channels jointly\n", - " if n_channel == 1:\n", - " axis_norm = (0,1) # normalize channels independently\n", - " print(\"Normalizing image channels independently\")\n", - "\n", - " if n_channel > 1:\n", - " axis_norm = (0,1,2) # normalize channels jointly\n", - " print(\"Normalizing image channels jointly\") \n", - " sys.stdout.flush() \n", - " \n", - " model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n", - " \n", - " names = [os.path.basename(f) for f in sorted(glob(Data_folder))] \n", - " Nuclei_number = []\n", - "\n", - " # modify the names to suitable form: path_images/image_numberX.tif\n", - " FILEnames = []\n", - " for m in names:\n", - " m = Results_folder+'/'+m\n", - " FILEnames.append(m)\n", - "\n", - " # Create a list of name with no extension\n", - " \n", - " name_no_extension=[]\n", - " for n in names:\n", - " name_no_extension.append(os.path.splitext(n)[0])\n", - " \n", - " # Save all ROIs and masks into results folder\n", - " \n", - " for i in range(len(X)):\n", - " img = normalize(X[i], 1,99.8, axis = axis_norm)\n", - " labels, polygons = model.predict_instances(img)\n", - " \n", - " os.chdir(Results_folder)\n", - "\n", - " if Mask_images:\n", - " imsave(FILEnames[i], labels, polygons)\n", - "\n", - " if Region_of_interests:\n", - " export_imagej_rois(name_no_extension[i], polygons['coord'])\n", - "\n", - " if Tracking_file:\n", - " Tracking_image = np.zeros((img.shape[1], img.shape[0]))\n", - " for point in polygons['points']:\n", - " cv2.circle(Tracking_image,tuple(point),0,(1), -1)\n", - " \n", - " Tracking_image_32 = img_as_float32(Tracking_image, force_copy=False)\n", - " Tracking_image_8 = img_as_ubyte(Tracking_image, force_copy=True) \n", - " Tracking_image_8_rot = np.rot90(Tracking_image_8, axes=(0, 1))\n", - " Tracking_image_8_rot_flip = np.flipud(Tracking_image_8_rot)\n", - " imsave(Results_folder+\"/\"+str(name_no_extension[i])+\"_tracking_file.tif\", Tracking_image_8_rot_flip, compress=ZIP_DEFLATED)\n", - " \n", - " Nuclei_centre_coordinate = polygons['points']\n", - " my_df2 = pd.DataFrame(Nuclei_centre_coordinate)\n", - " my_df2.columns =['Y', 'X']\n", - " \n", - " my_df2.to_csv(Results_folder+'/'+name_no_extension[i]+'_Nuclei_centre.csv', index=False, header=True)\n", - "\n", - " Nuclei_array = polygons['coord']\n", - " Nuclei_array2 = [names[i], Nuclei_array.shape[0]]\n", - " Nuclei_number.append(Nuclei_array2) \n", - "\n", - " my_df = pd.DataFrame(Nuclei_number)\n", - " my_df2.columns =['Frame number', 'Number of objects']\n", - " my_df.to_csv(Results_folder+'/Nuclei_count.csv', index=False, header=False)\n", - " \n", - "# One example is displayed\n", - "\n", - " print(\"One example image is displayed bellow:\")\n", - " plt.figure(figsize=(10,10))\n", - " plt.imshow(img if img.ndim==2 else img[...,:3], clim=(0,1), cmap='gray')\n", - " plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)\n", - " plt.axis('off');\n", - " 
plt.savefig(name_no_extension[i]+\"_overlay.tif\")\n", - "\n", - "# Here is the code to analyse stacks\n", - "\n", - "if Data_type == 2 :\n", - " print(\"Stacks are now beeing predicted\")\n", - " np.random.seed(42)\n", - " lbl_cmap = random_label_cmap()\n", - "\n", - " # normalize channels independently\n", - " axis_norm = (0,1) \n", - " \n", - " model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n", - " \n", - " for image in os.listdir(Data_folder):\n", - " print(\"Performing prediction on: \"+image)\n", - "\n", - " Number_of_nuclei_list = []\n", - " Number_of_frame_list = []\n", - "\n", - " timelapse = imread(Data_folder+\"/\"+image)\n", - "\n", - " short_name = os.path.splitext(image) \n", - " \n", - " timelapse = normalize(timelapse, 1,99.8, axis=(0,)+tuple(1+np.array(axis_norm)))\n", - " \n", - "\n", - " if Region_of_interests: \n", - " polygons = [model.predict_instances(frame)[1]['coord'] for frame in tqdm(timelapse)] \n", - " export_imagej_rois(Results_folder+\"/\"+str(short_name[0]), polygons, compression=ZIP_DEFLATED) \n", - " \n", - " n_timepoint = timelapse.shape[0]\n", - " prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n", - " Tracking_stack = np.zeros((n_timepoint, timelapse.shape[2], timelapse.shape[1]))\n", - "\n", - "# Analyse each time points one after the other\n", - " if Mask_images or Tracking_file:\n", - " for t in range(n_timepoint):\n", - " img_t = timelapse[t]\n", - " labels, polygons = model.predict_instances(img_t) \n", - " prediction_stack[t] = labels\n", - " Nuclei_array = polygons['coord']\n", - " Nuclei_array2 = [str(t), Nuclei_array.shape[0]]\n", - " Number_of_nuclei_list.append(Nuclei_array2)\n", - " Number_of_frame_list.append(t)\n", - "\n", - "# Create a tracking file for trackmate\n", - "\n", - " for point in polygons['points']:\n", - " cv2.circle(Tracking_stack[t],tuple(point),0,(1), -1)\n", - "\n", - " prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n", - " Tracking_stack_32 = img_as_float32(Tracking_stack, force_copy=False)\n", - " Tracking_stack_8 = img_as_ubyte(Tracking_stack_32, force_copy=True)\n", - " \n", - " Tracking_stack_8_rot = np.rot90(Tracking_stack_8, axes=(1,2))\n", - " Tracking_stack_8_rot_flip = np.fliplr(Tracking_stack_8_rot)\n", - "\n", - "# Export a csv file containing the number of nuclei detected at each frame\n", - " my_df = pd.DataFrame(Number_of_nuclei_list)\n", - " my_df.to_csv(Results_folder+'/'+str(short_name[0])+'_Nuclei_number.csv', index=False, header=False)\n", - "\n", - " os.chdir(Results_folder)\n", - " if Mask_images:\n", - " imsave(str(short_name[0])+\".tif\", prediction_stack_32, compress=ZIP_DEFLATED)\n", - " if Tracking_file:\n", - " imsave(str(short_name[0])+\"_tracking_file.tif\", Tracking_stack_8_rot_flip, compress=ZIP_DEFLATED)\n", - "\n", - " # Object detected vs frame number\n", - " plt.figure(figsize=(20,5))\n", - " my_df.plot()\n", - " plt.title('Number of objects vs frame number')\n", - " plt.ylabel('Number of detected objects')\n", - " plt.xlabel('Frame number')\n", - " plt.legend()\n", - " plt.savefig(Results_folder+'/'+str(short_name[0])+'_Object_detected_vs_frame_number.png',bbox_inches='tight',pad_inches=0)\n", - " plt.show() \n", - "\n", - "print(\"Predictions completed\") " - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KjGHBGmxlk9B" - }, - "source": [ - "## **6.2. 
Generate prediction(s) from unseen dataset (Big data)**\n", - "---\n", - "\n", - "You can use this section of the notebook to generate predictions on very large images. Compatible file formats include .Tif and .svs files.\n", - "\n", - "**`Data_folder`:** This folder should contains the images that you want to predict using the network that you trained.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output ROI.\n", - "\n", - "\n", - "In stardist the following results can be exported:\n", - "- Region of interest (ROI) that can be opened in ImageJ / Fiji. The ROI are saved inside of a .zip file in your choosen result folder. To open the ROI in Fiji, just drag and drop the zip file ! IMPORTANT: ROI files cannot be exported for extremely large images.\n", - "- The predicted mask images\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "jxjHeOFFleSV", - "cellView": "form" - }, - "source": [ - "#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n", - "\n", - "\n", - "start = time.time()\n", - "\n", - "\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "Results_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "# model name and path\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "\n", - "#@markdown #####To analyse very large image, your images need to be divided into blocks. Each blocks will then be processed independently and re-assembled to generate the final image. \n", - "#@markdown #####Here you can choose the dimension of the block.\n", - "\n", - "block_size_Y = 1024#@param {type:\"number\"}\n", - "block_size_X = 1024#@param {type:\"number\"}\n", - "\n", - "\n", - "#@markdown #####Here you can the amount of overlap between each block.\n", - "min_overlap = 50#@param {type:\"number\"}\n", - "\n", - "#@markdown #####To analyse large blocks, your blocks need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final block. \n", - "\n", - "n_tiles_Y = 1#@param {type:\"number\"}\n", - "n_tiles_X = 1#@param {type:\"number\"}\n", - "\n", - "\n", - "#@markdown ###What outputs would you like to generate? \n", - "Mask_images = True #@param {type:\"boolean\"}\n", - "\n", - "Region_of_interests = True #@param {type:\"boolean\"}\n", - "\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n", - "\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "\n", - "#Create a temp folder to save Zarr files\n", - "\n", - "Temp_folder = \"/content/Temp_folder\"\n", - "\n", - "if os.path.exists(Temp_folder):\n", - " shutil.rmtree(Temp_folder)\n", - "os.makedirs(Temp_folder)\n", - "\n", - "\n", - "# mi, ma = np.percentile(img[::8], [1,99.8]) # compute percentiles from low-resolution image\n", - "# mi, ma = np.percentile(img[13000:16000,13000:16000], [1,99.8]) # compute percentiles from smaller crop\n", - "mi, ma = 0, 255 # use min and max dtype values (suitable here)\n", - "normalizer = MyNormalizer(mi, ma)\n", - "\n", - "\n", - "np.random.seed(16)\n", - "lbl_cmap = random_label_cmap()\n", - "\n", - "#Load the StarDist model\n", - "\n", - "model = StarDist2D(None, name=Prediction_model_name, basedir=Prediction_model_path)\n", - "\n", - "\n", - "for image in os.listdir(Data_folder):\n", - " print(\"Performing prediction on: \"+image)\n", - "\n", - " X = imread(Data_folder+\"/\"+image)\n", - "\n", - " print(\"Image dimension \"+str(X.shape))\n", - "\n", - " short_name = os.path.splitext(image)\n", - "\n", - " n_channel = 1 if X.ndim == 2 else X.shape[-1]\n", - " \n", - " # axis_norm = (0,1,2) # normalize channels jointly\n", - " if n_channel == 1:\n", - " axis_norm = (0,1) # normalize channels independently\n", - " print(\"Normalizing image channels independently\")\n", - " block_size = (block_size_Y, block_size_X)\n", - " min_overlap = (min_overlap, min_overlap)\n", - " n_tiles = (n_tiles_Y, n_tiles_X)\n", - " axes=\"YX\"\n", - "\n", - " if n_channel > 1:\n", - " axis_norm = (0,1,2) # normalize channels jointly\n", - " print(\"Normalizing image channels jointly\")\n", - " axes=\"YXC\"\n", - " block_size = (block_size_Y, block_size_X, 3)\n", - " n_tiles = (n_tiles_Y, n_tiles_X, 1)\n", - " min_overlap = (min_overlap, min_overlap, 0) \n", - " sys.stdout.flush()\n", - " \n", - " zarr.save_array(str(Temp_folder+\"/image.zarr\"), X)\n", - " del X\n", - " img = zarr.open(str(Temp_folder+\"/image.zarr\"), mode='r')\n", - " \n", - " labels = zarr.open(str(Temp_folder+\"/labels.zarr\"), mode='w', shape=img.shape[:3], chunks=img.chunks[:3], dtype=np.int32)\n", - " \n", - " \n", - " labels, polygons = model.predict_instances_big(img, axes=axes, block_size=block_size, min_overlap=min_overlap, context=None,\n", - " normalizer=normalizer, show_progress=True, n_tiles=n_tiles)\n", - " \n", - "# Save the predicted mask in the result folder\n", - " os.chdir(Results_folder)\n", - " if Mask_images:\n", - " imsave(str(short_name[0])+\".tif\", labels, compress=ZIP_DEFLATED)\n", - " if Region_of_interests: \n", - " export_imagej_rois(str(short_name[0])+'labels_roi.zip', polygons['coord'], compression=ZIP_DEFLATED)\n", - "\n", - " del img\n", - "\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "# One example image \n", - "\n", - "fig, (a,b) = plt.subplots(1,2, figsize=(20,20))\n", - "a.imshow(labels[::8,::8], cmap='tab20b')\n", - "b.imshow(labels[::8,::8], cmap=lbl_cmap)\n", - "a.axis('off'); b.axis('off');\n", - "None;" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.3. 
Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UvSlTaH14s3t" - }, - "source": [ - "\n", - "#**Thank you for using StarDist 2D!**" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"StarDist_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1610969691998},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **StarDist (2D)**\n","---\n","\n","**StarDist 2D** is a deep-learning method that can be used to segment cell nuclei from bioimages and was first published by [Schmidt *et al.* in 2018, on arXiv](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 2D network is based on an adapted U-Net network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 2D dataset. If you are interested in 3D dataset, you should use the StarDist 3D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. 
(https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available in GitHub:\n","/~https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model**. 
The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed. This can include Test dataset for which you have the equivalent output and can compare to what the network provides.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - Masks \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"XuwTHSva_Y5K"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["#@markdown ##Install StarDist and dependencies\n","\n","# Install packages which are not included in Google Colab\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install stardist # contains tools to operate STARDIST.\n","!pip install gputools # improves STARDIST performances\n","#!pip install edt # improves STARDIST performances\n","!pip install wget\n","!pip install fpdf\n","!pip install PTable # Nice tables \n","!pip install zarr\n","!pip install imagecodecs\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"e_oQT-9180CX"},"source":["\n","## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"2uHigbVJ9CUh"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"HMAdm-Mc9HFz","cellView":"form"},"source":["#@markdown ##Load key dependencies\n","\n","Notebook_version = '1.13'\n","Network = 'StarDist 2D'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","\n","#%load_ext memory_profiler\n","\n","\n","%tensorflow_version 1.x\n","\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","\n","import imagecodecs\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available, relabel_image_stardist, random_label_cmap, relabel_image_stardist, _draw_polygons, export_imagej_rois\n","from stardist.models import Config2D, StarDist2D, StarDistData2D # import objects\n","from stardist.matching import matching_dataset\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","from PIL import Image\n","import zarr\n","from zipfile import ZIP_DEFLATED\n","from csbdeep.data import Normalizer, normalize_mi_ma\n","import imagecodecs\n","\n","\n","class MyNormalizer(Normalizer):\n"," def __init__(self, mi, ma):\n"," self.mi, self.ma = mi, ma\n"," def before(self, x, axes):\n"," return normalize_mi_ma(x, self.mi, self.ma, dtype=np.float32)\n"," def after(*args, **kwargs):\n"," assert False\n"," @property\n"," def do_after(self):\n"," return False\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, 
imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32, img_as_ubyte, img_as_float\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","import cv2\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print('------------------------------------------')\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","# PDF export\n","\n","def pdf_export(trained=False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The GPU used was a '+gpu_name+'.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)\n"," \n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table>
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>number_of_steps</td><td>{3}</td></tr>
<tr><td>percentage_validation</td><td>{4}</td></tr>
<tr><td>n_rays</td><td>{5}</td></tr>
<tr><td>grid_parameter</td><td>{6}</td></tr>
<tr><td>initial_learning_rate</td><td>{7}</td></tr>
</table>
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,n_rays,grid_parameter,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_StarDist2D.png').shape\n"," pdf.image('/content/TrainingDataExample_StarDist2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," if augmentation:\n"," ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_4, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Stardist 2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," #image = header[0]\n"," #PvGT_IoU = header[1]\n"," fp = header[2]\n"," tp = header[3]\n"," fn = header[4]\n"," precision = header[5]\n"," recall = header[6]\n"," acc = header[7]\n"," f1 = header[8]\n"," n_true = header[9]\n"," n_pred = header[10]\n"," mean_true = header[11]\n"," mean_matched = header[12]\n"," panoptic = header[13]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(\"image #\",\"Prediction v. 
GT IoU\",'false pos.','true pos.','false neg.',precision,recall,acc,f1,n_true,n_pred,mean_true,mean_matched,panoptic)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," #image = row[0]\n"," PvGT_IoU = row[1]\n"," fp = row[2]\n"," tp = row[3]\n"," fn = row[4]\n"," precision = row[5]\n"," recall = row[6]\n"," acc = row[7]\n"," f1 = row[8]\n"," n_true = row[9]\n"," n_pred = row[10]\n"," mean_true = row[11]\n"," mean_matched = row[12]\n"," panoptic = row[13]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(str(i),str(round(float(PvGT_IoU),3)),fp,tp,fn,str(round(float(precision),3)),str(round(float(recall),3)),str(round(float(acc),3)),str(round(float(f1),3)),n_true,n_pred,str(round(float(mean_true),3)),str(round(float(mean_matched),3)),str(round(float(panoptic),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th><th>{7}</th><th>{8}</th><th>{9}</th><th>{10}</th><th>{11}</th><th>{12}</th><th>{13}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td><td>{7}</td><td>{8}</td><td>{9}</td><td>{10}</td><td>{11}</td><td>{12}</td><td>{13}</td></tr>
</table>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- StarDist 2D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyjdhq0Zt5TF"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
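Once the Drive is mounted, a quick sanity check can save time before filling in the parameter fields of section 3. The sketch below is only an illustration: the two folder paths are hypothetical placeholders, and it assumes the mount cell above has been run. It lists the training folders and verifies that every source image has a mask with exactly the same file name, which is a requirement stated in section 0.

```python
import os

# Hypothetical example paths on the mounted Drive -- replace with your own folders
source_folder = "/content/gdrive/MyDrive/StarDist_dataset/Training - Images"
target_folder = "/content/gdrive/MyDrive/StarDist_dataset/Training - Masks"

# Confirm the folders are visible from the notebook
for folder in (source_folder, target_folder):
    if not os.path.isdir(folder):
        print("Folder not found (is the Drive mounted?): " + folder)

# Source images and target masks must share exactly the same file names
source_files = sorted(f for f in os.listdir(source_folder) if f.lower().endswith(".tif"))
target_files = sorted(f for f in os.listdir(target_folder) if f.lower().endswith(".tif"))
print(str(len(source_files)) + " source images and " + str(len(target_files)) + " target masks found")

unmatched = sorted(set(source_files) ^ set(target_files))
if unmatched:
    print("Files without a counterpart in the other folder: " + str(unmatched))
```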
"]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 50-100 epochs, but a full training should run for up to 400 epochs. Evaluate the performance after training (see 5.). **Default value: 100**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 2**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`patch_size`:** Input the size of the patches use to train StarDist 2D (length of a side). The value should be smaller or equal to the dimensions of the image. Make the patch size as large as possible and divisible by 8. **Default value: dimension of the training images** \n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set number of rays (corners) used for StarDist (for instance, a square has 4 corners). **Default value: 32** \n","\n","**`grid_parameter`:** increase this number if the cells/nuclei are very large or decrease it if they are very small. **Default value: 2**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size value until the OOM error disappear.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","#trained_model = model_path \n","\n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 4 #@param {type:\"number\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","patch_size = 1024 #@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","n_rays = 32 #@param {type:\"number\"}\n","grid_parameter = 2#@param [1, 2, 4, 8, 16, 32] {type:\"raw\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 2\n"," n_rays = 32\n"," percentage_validation = 10\n"," grid_parameter = 2\n"," initial_learning_rate = 0.0003\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n","\n"," \n","# Here we open will randomly chosen input and output image\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check the image dimensions\n","\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","print('Loaded images (width, length) =', x.shape)\n","\n","# If default parameters, patch size is the same as image size\n","if (Use_Default_Advanced_Parameters):\n"," patch_size = min(Image_Y, Image_X)\n"," \n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print(bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","if patch_size > 2048:\n"," patch_size = 2048\n"," print(bcolors.WARNING + \" Your image dimension is large; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","# Here we check that the patch_size is divisible by 16\n","if not patch_size % 16 == 0:\n"," patch_size = ((int(patch_size / 16)-1) * 16)\n"," print(bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is:\",patch_size)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not 
ran)\n","\n","Use_Data_augmentation = False\n","\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","norm = simple_norm(x, percent = 99)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest', norm=norm, cmap='magma')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest', cmap=lbl_cmap)\n","plt.title('Training target')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_StarDist2D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here via random rotations, flips, and intensity changes.\n","\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 4 #@param {type:\"slider\", min:1, max:10, step:1}\n","\n","\n","def random_fliprot(img, mask): \n"," assert img.ndim >= mask.ndim\n"," axes = tuple(range(mask.ndim))\n"," perm = tuple(np.random.permutation(axes))\n"," img = img.transpose(perm + tuple(range(mask.ndim, img.ndim))) \n"," mask = mask.transpose(perm) \n"," for ax in axes: \n"," if np.random.rand() > 0.5:\n"," img = np.flip(img, axis=ax)\n"," mask = np.flip(mask, axis=ax)\n"," return img, mask \n","\n","def random_intensity_change(img):\n"," img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)\n"," return img\n","\n","\n","def augmenter(x, y):\n"," \"\"\"Augmentation of a single input/label image pair.\n"," x is an input image\n"," y is the corresponding ground-truth label image\n"," \"\"\"\n"," x, y = random_fliprot(x, y)\n"," x = random_intensity_change(x)\n"," # add some gaussian noise\n"," sig = 0.02*np.random.uniform(0,1)\n"," x = x + sig*np.random.normal(0,1,x.shape)\n"," return x, y\n","\n","\n","\n","if Use_Data_augmentation:\n"," augmenter = augmenter\n"," print(\"Data augmentation enabled\")\n","\n","\n","if not Use_Data_augmentation:\n"," augmenter = None\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. 
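As a quick sanity check of the augmentation defined in section 3.2 above, the sketch below applies `random_fliprot` and `random_intensity_change` to one training pair and displays the result. This is only an illustrative snippet: it assumes the cells of sections 1.3, 3.1 and 3.2 have already been run, so that `imread`, `normalize`, `lbl_cmap`, `Training_source`, `Training_target` and `random_choice` exist; the normalisation percentiles simply follow the values used elsewhere in this notebook.

```python
# Preview one augmented training pair (assumes sections 1.3, 3.1 and 3.2 have been run)
x_demo = normalize(imread(Training_source + "/" + random_choice), 1, 99.8, axis=(0, 1))
y_demo = imread(Training_target + "/" + random_choice)

# Apply the geometric and intensity transformations defined in section 3.2
x_aug, y_aug = random_fliprot(x_demo, y_demo)
x_aug = random_intensity_change(x_aug)

plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plt.imshow(x_aug, interpolation='nearest', cmap='magma')
plt.title('Augmented source')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(y_aug, interpolation='nearest', cmap=lbl_cmap)
plt.title('Augmented target')
plt.axis('off')
plt.show()
```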
**This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"2D_versatile_fluo_from_Stardist_Fiji\" #@param [\"Model_from_file\", \"2D_versatile_fluo_from_Stardist_Fiji\", \"2D_Demo_Model_from_Stardist_Github\", \"Versatile_H&E_nuclei\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 2D model provided in the Stardist 2D github ------------------------\n","\n"," if pretrained_model_choice == \"2D_Demo_Model_from_Stardist_Github\":\n"," pretrained_model_name = \"2D_Demo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_Github\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/2D_demo/thresholds.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_best.h5?raw=true\", pretrained_model_path) \n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/2D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Download the Demo 2D_versatile_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"2D_versatile_fluo_from_Stardist_Fiji\":\n"," print(\"Downloading the 2D_versatile_fluo_from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_fluo\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_fluo.zip\", 
pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_fluo.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","# --------------------- Download the Versatile (H&E nuclei)_fluo_from_Stardist_Fiji ------------------------\n","\n"," if pretrained_model_choice == \"Versatile_H&E_nuclei\":\n"," print(\"Downloading the Versatile_H&E_nuclei from_Stardist_Fiji\")\n"," pretrained_model_name = \"2D_versatile_he\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," \n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," \n"," wget.download(\"https://cloud.mpi-cbg.de/index.php/s/1k5Zcy7PpFWRb0Q/download?path=/versatile&files=2D_versatile_he.zip\", pretrained_model_path)\n"," \n"," with zipfile.ZipFile(pretrained_model_path+\"/2D_versatile_he.zip\", 'r') as zip_ref:\n"," zip_ref.extractall(pretrained_model_path)\n"," \n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_best.h5\")\n","\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist' + W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","\n","Training_source_dir = Training_source\n","Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. 
If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","\n","# --------------------- Piece of code to remove alpha channel in RGBA images ------------------------\n","\n","if n_channel == 4:\n"," X[:] = [i[:,:,:3] for i in X]\n"," print(\"The alpha channel has been removed\")\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","\n","\n","if not Use_Data_augmentation:\n"," augmenter = None\n"," \n","\n","#Normalize images and fill small label holes.\n","\n","if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","\n","if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). \n","#It is advisable to use 10 % of your training dataset for validation. This ensures the truthfull validation error value. If only few validation images are used network may choose too easy or too challenging images for validation. \n","# split training data (images and masks) into training images and validation images.\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","# Currently always false for stability\n","use_gpu = False and gputools_available()\n","\n","#Here we ensure that our network has a minimal number of steps\n","\n","if (Use_Default_Advanced_Parameters) or (number_of_steps == 0): \n"," # number_of_steps= (int(len(X)/batch_size)+1)\n"," number_of_steps = int(Image_X*Image_Y/(patch_size*patch_size))*(int(len(X)/batch_size)+1)\n"," if (Use_Data_augmentation):\n"," augmentation_factor = Multiply_dataset_by\n"," number_of_steps = number_of_steps * augmentation_factor\n","\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","\n","\n","conf = Config2D (\n"," n_rays = n_rays,\n"," use_gpu = use_gpu,\n"," train_batch_size = batch_size,\n"," n_channel_in = n_channel,\n"," train_patch_size = (patch_size, patch_size),\n"," grid = (grid_parameter, grid_parameter),\n"," train_learning_rate = initial_learning_rate,\n",")\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist2D(conf, name=model_name, basedir=model_path)\n","\n","# 
--------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","\n","\n","# --------------------- ---------------------- ------------------------\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(list(Y), np.median)\n","fov = np.array(model._axes_tile_overlap('YX'))\n","if any(median_size > fov):\n"," print(bcolors.WARNING+\"WARNING: median object size larger than field of view of the neural network.\")\n","print(conf)\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","print(\"Number of steps: \"+str(number_of_steps))\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","**Of Note:** At the end of the training, your model will be automatically exported so it can be used in the Stardist Fiji plugin. You can find it in your model folder (TF_SavedModel.zip). In Fiji, Make sure to choose the right version of tensorflow. You can check at: Edit-- Options-- Tensorflow. Choose the version 1.4 (CPU or GPU depending on your system).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","\n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","print(\"Network optimization in progress\")\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","\n","print(\"Done\")\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
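# --- Illustrative note (added for clarity; not part of the original cell) ---
# model.optimize_thresholds(X_val, Y_val) above searches for the probability and
# non-maximum-suppression thresholds that maximise segmentation accuracy on the
# validation images and stores them with the model (thresholds.json), so they are
# applied automatically at prediction time. They can also be overridden explicitly
# when predicting, for example:
#   labels, details = model.predict_instances(img, prob_thresh=0.5, nms_thresh=0.4)
# A higher prob_thresh keeps only more confident objects; a higher nms_thresh
# allows more strongly overlapping candidates to survive.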
\n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","model.export_TF()\n","\n","print(\"Your model has been sucessfully exported and can now also be used in the Stardist Fiji plugin\")\n","\n","pdf_export(trained=True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","#Create a pdf document with training summary"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows you to perform important quality checks on the validity and generalisability of the trained model. \n","\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else: \n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!')\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. 
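To make this check concrete, here is a minimal sketch (an addition for illustration, not part of the original notebook) that reads the `training_evaluation.csv` written by the training cell in section 4.2 and flags a possible overfitting situation; it assumes the `model_path` and `model_name` variables defined in section 3:

```python
import pandas as pd

# Loss history written during training (section 4.2)
csv_path = model_path + "/" + model_name + "/Quality Control/training_evaluation.csv"
history = pd.read_csv(csv_path)

best_epoch = int(history["val_loss"].idxmin()) + 1
print(f"Lowest validation loss reached at epoch {best_epoch} of {len(history)}")

# If the validation loss has clearly risen again since its minimum while the
# training loss kept decreasing, the model is likely overfitting.
if history["val_loss"].iloc[-1] > 1.1 * history["val_loss"].min():
    print("Validation loss increased noticeably after its minimum: possible overfitting.")
```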
Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** (IuO) metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","Here, the IuO is both calculated over the whole image and on a per-object basis. The value displayed below is the IuO value calculated over the entire image. The IuO value calculated on a per-object basis is used to calculate the other metrics displayed.\n","\n","“n_true” refers to the number of objects present in the ground truth image. “n_pred” refers to the number of objects present in the predicted image. \n","\n","When a segmented object has an IuO value above 0.5 (compared to the corresponding ground truth), it is then considered a true positive. 
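As a minimal illustration of the image-level metric described above (added here for clarity; the file names are placeholders and both images are assumed to be label masks of identical shape), the whole-image IoU can be computed by binarising the two masks:

```python
import numpy as np
from skimage import io

ground_truth = io.imread("ground_truth_mask.tif") > 0   # placeholder file name
prediction = io.imread("predicted_mask.tif") > 0        # placeholder file name

intersection = np.logical_and(ground_truth, prediction).sum()
union = np.logical_or(ground_truth, prediction).sum()
iou_score = intersection / union if union > 0 else 0.0
print(f"Whole-image IoU: {iou_score:.3f}")
```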
The number of “**true positives**” is available in the table below. The number of “false positive” is then defined as “**false positive**” = “n_pred” - “true positive”. The number of “false negative” is defined as “false negative” = “n_true” - “true positive”.\n","\n","The mean_matched_score is the mean IoUs of matched true positives. The mean_true_score is the mean IoUs of matched true positives but normalized by the total number of ground truth objects. The panoptic_quality is calculated as described by [Kirillov et al. 2019](https://arxiv.org/abs/1801.00868).\n","\n","For more information about the other metric displayed, please consult the SI of the paper describing ZeroCostDL4Mic.\n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","from stardist.matching import matching\n","from stardist.plot import render_label, render_label_pred \n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n"," \n","#Normalize images.\n","\n","if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n","if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n","\n","model = StarDist2D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file, delimiter=\",\")\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\", \"false positive\", \"true positive\", \"false negative\", \"precision\", \"recall\", \"accuracy\", \"f1 score\", \"n_true\", \"n_pred\", \"mean_true_score\", \"mean_matched_score\", \"panoptic_quality\"]) \n","\n","# define the images\n","\n"," for n in os.listdir(Source_QC_folder):\n"," \n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n"," # Calculate the matching (with IoU threshold `thresh`) and all metrics\n"," \n"," stats = matching(test_ground_truth_image, test_prediction, thresh=0.5)\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","\n"," # Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score), str(stats.fp), str(stats.tp), str(stats.fn), str(stats.precision), str(stats.recall), str(stats.accuracy), str(stats.f1), str(stats.n_true), str(stats.n_pred), str(stats.mean_true_score), str(stats.mean_matched_score), str(stats.panoptic_quality)])\n","\n","from tabulate import tabulate\n","\n","df = pd.read_csv (QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\")\n","print(tabulate(df, headers='keys', tablefmt='psql'))\n","\n","\n","from astropy.visualization import simple_norm\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Source_QC_folder)):\n"," \n","\n"," plt.figure(figsize=(25,5))\n"," if n_channel > 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file))\n"," if n_channel == 1:\n"," source_image = io.imread(os.path.join(Source_QC_folder, file), as_gray = True)\n","\n"," target_image = io.imread(os.path.join(Target_QC_folder, file), as_gray = True)\n"," prediction = io.imread(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\"+file, as_gray = True)\n","\n"," stats = matching(prediction, target_image, thresh=0.5)\n","\n"," target_image_mask = np.empty_like(target_image)\n"," target_image_mask[target_image > 0] = 255\n"," target_image_mask[target_image == 0] = 0\n"," \n"," prediction_mask = np.empty_like(prediction)\n"," prediction_mask[prediction > 0] = 255\n"," prediction_mask[prediction == 0] = 0\n","\n"," intersection = np.logical_and(target_image_mask, prediction_mask)\n"," union = np.logical_or(target_image_mask, prediction_mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," norm = simple_norm(source_image, percent = 99)\n","\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," if n_channel > 1:\n"," plt.imshow(source_image)\n"," if n_channel == 1:\n"," plt.imshow(source_image, aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," 
plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(target_image_mask, cmap='Greens')\n"," plt.imshow(prediction_mask, alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(iou_score,3 )));\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["\n","\n","## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive.\n","\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you will train.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n","In stardist the following results can be exported:\n","- Region of interest (ROI) that can be opened in ImageJ / Fiji. The ROI are saved inside of a .zip file in your choosen result folder. To open the ROI in Fiji, just drag and drop the zip file !**\n","- The predicted mask images\n","- A tracking file that can easily be imported into Trackmate to track the nuclei.\n","- A CSV file that contains the number of nuclei detected per image. \n","- A CSV file that contains the coordinate the centre of each detected nuclei (single image only). 
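To make the link between these outputs and the cell below more concrete, here is a condensed sketch of the single-image prediction loop (for illustration only; the folder paths and model name are placeholders, and the full cell below additionally handles stacks, tracking files and the CSV exports):

```python
import os
from glob import glob
import numpy as np
from tifffile import imread, imsave
from csbdeep.utils import normalize
from stardist import export_imagej_rois
from stardist.models import StarDist2D

# Load a previously trained model (placeholder name and path)
model = StarDist2D(None, name="my_stardist_model", basedir="/content/gdrive/MyDrive/models")

for path in sorted(glob("/content/gdrive/MyDrive/to_predict/*.tif")):  # placeholder folder
    img = normalize(imread(path), 1, 99.8, axis=(0, 1))
    labels, details = model.predict_instances(img)  # labels: mask image, details['coord']: object outlines
    name = os.path.splitext(os.path.basename(path))[0]
    imsave("/content/results/" + name + "_mask.tif", labels.astype(np.uint16))      # predicted mask image
    export_imagej_rois("/content/results/" + name + "_rois.zip", details["coord"])  # ROI zip for Fiji
```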
\n","\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["Single_Images = 1\n","Stacks = 2\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Are your data single images or stacks?\n","\n","Data_type = Single_Images #@param [\"Single_Images\", \"Stacks\"] {type:\"raw\"}\n","\n","#@markdown ###What outputs would you like to generate?\n","Region_of_interests = True #@param {type:\"boolean\"}\n","Mask_images = True #@param {type:\"boolean\"}\n","Tracking_file = False #@param {type:\"boolean\"}\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(bcolors.WARNING+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","#single images\n","\n","if Data_type == 1 :\n","\n"," Data_folder = Data_folder+\"/*.tif\"\n","\n"," print(\"Single images are now beeing predicted\")\n"," np.random.seed(16)\n"," lbl_cmap = random_label_cmap()\n"," X = sorted(glob(Data_folder))\n"," X = list(map(imread,X))\n"," n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n","\n"," if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\") \n"," sys.stdout.flush() \n"," \n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," names = [os.path.basename(f) for f in sorted(glob(Data_folder))] \n"," Nuclei_number = []\n","\n"," # modify the names to suitable form: path_images/image_numberX.tif\n"," FILEnames = []\n"," for m in names:\n"," m = Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Create a list of name with no extension\n"," \n"," name_no_extension=[]\n"," for n in names:\n"," name_no_extension.append(os.path.splitext(n)[0])\n"," \n"," # Save all ROIs and masks into results folder\n"," \n"," for i in range(len(X)):\n"," img = normalize(X[i], 1,99.8, axis = axis_norm)\n"," labels, polygons = model.predict_instances(img)\n"," \n"," os.chdir(Results_folder)\n","\n"," if Mask_images:\n"," imsave(FILEnames[i], labels, polygons)\n","\n"," if Region_of_interests:\n"," export_imagej_rois(name_no_extension[i], polygons['coord'])\n","\n"," if Tracking_file:\n"," 
Tracking_image = np.zeros((img.shape[1], img.shape[0]))\n"," for point in polygons['points']:\n"," cv2.circle(Tracking_image,tuple(point),0,(1), -1)\n"," \n"," Tracking_image_32 = img_as_float32(Tracking_image, force_copy=False)\n"," Tracking_image_8 = img_as_ubyte(Tracking_image, force_copy=True) \n"," Tracking_image_8_rot = np.rot90(Tracking_image_8, axes=(0, 1))\n"," Tracking_image_8_rot_flip = np.flipud(Tracking_image_8_rot)\n"," imsave(Results_folder+\"/\"+str(name_no_extension[i])+\"_tracking_file.tif\", Tracking_image_8_rot_flip, compress=ZIP_DEFLATED)\n"," \n"," Nuclei_centre_coordinate = polygons['points']\n"," my_df2 = pd.DataFrame(Nuclei_centre_coordinate)\n"," my_df2.columns =['Y', 'X']\n"," \n"," my_df2.to_csv(Results_folder+'/'+name_no_extension[i]+'_Nuclei_centre.csv', index=False, header=True)\n","\n"," Nuclei_array = polygons['coord']\n"," Nuclei_array2 = [names[i], Nuclei_array.shape[0]]\n"," Nuclei_number.append(Nuclei_array2) \n","\n"," my_df = pd.DataFrame(Nuclei_number)\n"," my_df2.columns =['Frame number', 'Number of objects']\n"," my_df.to_csv(Results_folder+'/Nuclei_count.csv', index=False, header=False)\n"," \n","# One example is displayed\n","\n"," print(\"One example image is displayed bellow:\")\n"," plt.figure(figsize=(10,10))\n"," plt.imshow(img if img.ndim==2 else img[...,:3], clim=(0,1), cmap='gray')\n"," plt.imshow(labels, cmap=lbl_cmap, alpha=0.5)\n"," plt.axis('off');\n"," plt.savefig(name_no_extension[i]+\"_overlay.tif\")\n","\n","# Here is the code to analyse stacks\n","\n","if Data_type == 2 :\n"," print(\"Stacks are now beeing predicted\")\n"," np.random.seed(42)\n"," lbl_cmap = random_label_cmap()\n","\n"," # normalize channels independently\n"," axis_norm = (0,1) \n"," \n"," model = StarDist2D(None, name = Prediction_model_name, basedir = Prediction_model_path)\n"," \n"," for image in os.listdir(Data_folder):\n"," print(\"Performing prediction on: \"+image)\n","\n"," Number_of_nuclei_list = []\n"," Number_of_frame_list = []\n","\n"," timelapse = imread(Data_folder+\"/\"+image)\n","\n"," short_name = os.path.splitext(image) \n"," \n"," timelapse = normalize(timelapse, 1,99.8, axis=(0,)+tuple(1+np.array(axis_norm)))\n"," \n","\n"," if Region_of_interests: \n"," polygons = [model.predict_instances(frame)[1]['coord'] for frame in tqdm(timelapse)] \n"," export_imagej_rois(Results_folder+\"/\"+str(short_name[0]), polygons, compression=ZIP_DEFLATED) \n"," \n"," n_timepoint = timelapse.shape[0]\n"," prediction_stack = np.zeros((n_timepoint, timelapse.shape[1], timelapse.shape[2]))\n"," Tracking_stack = np.zeros((n_timepoint, timelapse.shape[2], timelapse.shape[1]))\n","\n","# Analyse each time points one after the other\n"," if Mask_images or Tracking_file:\n"," for t in range(n_timepoint):\n"," img_t = timelapse[t]\n"," labels, polygons = model.predict_instances(img_t) \n"," prediction_stack[t] = labels\n"," Nuclei_array = polygons['coord']\n"," Nuclei_array2 = [str(t), Nuclei_array.shape[0]]\n"," Number_of_nuclei_list.append(Nuclei_array2)\n"," Number_of_frame_list.append(t)\n","\n","# Create a tracking file for trackmate\n","\n"," for point in polygons['points']:\n"," cv2.circle(Tracking_stack[t],tuple(point),0,(1), -1)\n","\n"," prediction_stack_32 = img_as_float32(prediction_stack, force_copy=False)\n"," Tracking_stack_32 = img_as_float32(Tracking_stack, force_copy=False)\n"," Tracking_stack_8 = img_as_ubyte(Tracking_stack_32, force_copy=True)\n"," \n"," Tracking_stack_8_rot = np.rot90(Tracking_stack_8, axes=(1,2))\n"," Tracking_stack_8_rot_flip = 
np.fliplr(Tracking_stack_8_rot)\n","\n","# Export a csv file containing the number of nuclei detected at each frame\n"," my_df = pd.DataFrame(Number_of_nuclei_list)\n"," my_df.to_csv(Results_folder+'/'+str(short_name[0])+'_Nuclei_number.csv', index=False, header=False)\n","\n"," os.chdir(Results_folder)\n"," if Mask_images:\n"," imsave(str(short_name[0])+\".tif\", prediction_stack_32, compress=ZIP_DEFLATED)\n"," if Tracking_file:\n"," imsave(str(short_name[0])+\"_tracking_file.tif\", Tracking_stack_8_rot_flip, compress=ZIP_DEFLATED)\n","\n"," # Object detected vs frame number\n"," plt.figure(figsize=(20,5))\n"," my_df.plot()\n"," plt.title('Number of objects vs frame number')\n"," plt.ylabel('Number of detected objects')\n"," plt.xlabel('Frame number')\n"," plt.legend()\n"," plt.savefig(Results_folder+'/'+str(short_name[0])+'_Object_detected_vs_frame_number.png',bbox_inches='tight',pad_inches=0)\n"," plt.show() \n","\n","print(\"Predictions completed\") "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"KjGHBGmxlk9B"},"source":["## **6.2. Generate prediction(s) from unseen dataset (Big data)**\n","---\n","\n","You can use this section of the notebook to generate predictions on very large images. Compatible file formats include .Tif and .svs files.\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you trained.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","\n","In stardist the following results can be exported:\n","- Region of interest (ROI) that can be opened in ImageJ / Fiji. The ROI are saved inside of a .zip file in your choosen result folder. To open the ROI in Fiji, just drag and drop the zip file ! IMPORTANT: ROI files cannot be exported for extremely large images.\n","- The predicted mask images\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"jxjHeOFFleSV","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","\n","start = time.time()\n","\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","#@markdown #####To analyse very large image, your images need to be divided into blocks. Each blocks will then be processed independently and re-assembled to generate the final image. \n","#@markdown #####Here you can choose the dimension of the block.\n","\n","block_size_Y = 1024#@param {type:\"number\"}\n","block_size_X = 1024#@param {type:\"number\"}\n","\n","\n","#@markdown #####Here you can the amount of overlap between each block.\n","min_overlap = 50#@param {type:\"number\"}\n","\n","#@markdown #####To analyse large blocks, your blocks need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final block. 
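# --- Illustrative note (added for clarity; not part of the original cell) ---
# Rough arithmetic for the block splitting: neighbouring blocks advance by about
# (block_size - min_overlap) pixels, so a 4096 x 4096 image processed with
# 1024-pixel blocks and a 50-pixel overlap gives approximately:
#
#   import math
#   blocks_y = math.ceil(4096 / (block_size_Y - min_overlap))   # ~5 blocks along Y
#   blocks_x = math.ceil(4096 / (block_size_X - min_overlap))   # ~5 blocks along X
#   print(blocks_y * blocks_x, "blocks in total (approximate)")
#
# min_overlap should be chosen larger than the largest object, so that nuclei cut
# at a block border are fully contained in at least one neighbouring block.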
\n","\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","\n","#@markdown ###What outputs would you like to generate? \n","Mask_images = True #@param {type:\"boolean\"}\n","\n","Region_of_interests = True #@param {type:\"boolean\"}\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#Create a temp folder to save Zarr files\n","\n","Temp_folder = \"/content/Temp_folder\"\n","\n","if os.path.exists(Temp_folder):\n"," shutil.rmtree(Temp_folder)\n","os.makedirs(Temp_folder)\n","\n","\n","# mi, ma = np.percentile(img[::8], [1,99.8]) # compute percentiles from low-resolution image\n","# mi, ma = np.percentile(img[13000:16000,13000:16000], [1,99.8]) # compute percentiles from smaller crop\n","mi, ma = 0, 255 # use min and max dtype values (suitable here)\n","normalizer = MyNormalizer(mi, ma)\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","\n","#Load the StarDist model\n","\n","model = StarDist2D(None, name=Prediction_model_name, basedir=Prediction_model_path)\n","\n","\n","for image in os.listdir(Data_folder):\n"," print(\"Performing prediction on: \"+image)\n","\n"," X = imread(Data_folder+\"/\"+image)\n","\n"," print(\"Image dimension \"+str(X.shape))\n","\n"," short_name = os.path.splitext(image)\n","\n"," n_channel = 1 if X.ndim == 2 else X.shape[-1]\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n"," if n_channel == 1:\n"," axis_norm = (0,1) # normalize channels independently\n"," print(\"Normalizing image channels independently\")\n"," block_size = (block_size_Y, block_size_X)\n"," min_overlap = (min_overlap, min_overlap)\n"," n_tiles = (n_tiles_Y, n_tiles_X)\n"," axes=\"YX\"\n","\n"," if n_channel > 1:\n"," axis_norm = (0,1,2) # normalize channels jointly\n"," print(\"Normalizing image channels jointly\")\n"," axes=\"YXC\"\n"," block_size = (block_size_Y, block_size_X, 3)\n"," n_tiles = (n_tiles_Y, n_tiles_X, 1)\n"," min_overlap = (min_overlap, min_overlap, 0) \n"," sys.stdout.flush()\n"," \n"," zarr.save_array(str(Temp_folder+\"/image.zarr\"), X)\n"," del X\n"," img = zarr.open(str(Temp_folder+\"/image.zarr\"), mode='r')\n"," \n"," labels = zarr.open(str(Temp_folder+\"/labels.zarr\"), mode='w', shape=img.shape[:3], chunks=img.chunks[:3], dtype=np.int32)\n"," \n"," \n"," labels, polygons = model.predict_instances_big(img, axes=axes, block_size=block_size, min_overlap=min_overlap, context=None,\n"," normalizer=normalizer, show_progress=True, n_tiles=n_tiles)\n"," \n","# Save the predicted mask in the result folder\n"," os.chdir(Results_folder)\n"," if Mask_images:\n"," imsave(str(short_name[0])+\".tif\", labels, compress=ZIP_DEFLATED)\n"," if Region_of_interests: \n"," export_imagej_rois(str(short_name[0])+'labels_roi.zip', polygons['coord'], compression=ZIP_DEFLATED)\n","\n"," del img\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = 
divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","# One example image \n","\n","fig, (a,b) = plt.subplots(1,2, figsize=(20,20))\n","a.imshow(labels[::8,::8], cmap='tab20b')\n","b.imshow(labels[::8,::8], cmap=lbl_cmap)\n","a.axis('off'); b.axis('off');\n","None;"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"VxGg1fmkFol2"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version now includes an automatic restart allowing to set the h5py library to v2.10.\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","\n","* This version also now includes built-in version check and the version log that you're reading now.\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using StarDist 2D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb index 0002133f..e0246bfe 100644 --- a/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/StarDist_3D_ZeroCostDL4Mic.ipynb @@ -1,1877 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "StarDist_3D_ZeroCostDL4Mic.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.7" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "IkSguVy8Xv83" - }, - "source": [ - "# **StarDist (3D)**\n", - "---\n", - "\n", - "**StarDist 3D** is a deep-learning method that can be used to segment cell nuclei from 3D bioimages and was first published by [Weigert *et al.* in 2019 on arXiv](https://arxiv.org/abs/1908.03636), extending to 3D the 2D appraoch from [Schmidt *et al.* in 2018](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 3D network is based on an adapted ResNet network architecture.\n", - "\n", - " **This particular notebook enables nuclei segmentation of 3D dataset. If you are interested in 2D dataset, you should use the StarDist 2D notebook instead.**\n", - "\n", - "---\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is largely based on the paper:\n", - "\n", - "**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. (https://arxiv.org/abs/1806.03535)\n", - "\n", - "and the 3D extension of the approach:\n", - "\n", - "**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n", - "\n", - "**The Original code** is freely available in GitHub:\n", - "/~https://github.com/mpicbg-csbd/stardist\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use our notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. 
This allows you to keep the original code piece in the cell as a comment." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gKDLkLWUd-YX" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - " For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n", - "\n", - "**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n", - "\n", - "The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n", - "\n", - "Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n", - "\n", - "Please note that you currently can **only use .tif files!**\n", - "\n", - "You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n", - "\n", - "Here's a common data structure that can work:\n", - "* Experiment A\n", - " - **Training dataset**\n", - " - Images of nuclei (Training_source)\n", - " - img_1.tif, img_2.tif, ...\n", - " - Masks (Training_target)\n", - " - img_1.tif, img_2.tif, ...\n", - " - **Quality control dataset**\n", - " - Images of nuclei\n", - " - img_1.tif, img_2.tif\n", - " - **Masks** \n", - " - img_1.tif, img_2.tif\n", - " - **Data to be predicted**\n", - " - **Results**\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "\n", - "\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. 
Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zCvebubeSaGY", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "# %tensorflow_version 1.x\n", - "\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sNIVx8_CLolt" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "01Djr8v-5pPk", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AdN8B91xZO0x" - }, - "source": [ - "# **2. Install StarDist and dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vvtlJacqSmOi", - "cellView": "form" - }, - "source": [ - "#@markdown ##Install StarDist and dependencies\n", - "\n", - "!pip install tifffile # contains tools to operate tiff-files\n", - "!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n", - "!pip install stardist # contains tools to operate STARDIST.\n", - "!pip install gputools\n", - "#!pip install edt\n", - "!pip install wget\n", - "!pip install fpdf\n", - "\n", - "#Force session restart\n", - "exit(0)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_IG7IqJtSuFY" - }, - "source": [ - "## **2.2. 
Restart your runtime**\n", - "---\n", - "\n", - "\n", - "\n", - "** Your Runtime has automatically restarted. This is normal.**\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "40rqHN2PTJlu" - }, - "source": [ - "## **2.3. Load key dependencies**\n", - "---\n", - " " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "fq21zJVFNASx", - "cellView": "form" - }, - "source": [ - "#@markdown ##Load key dependencies\n", - "\n", - "Notebook_version = ['1.12']\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory \n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "# %tensorflow_version 1.x\n", - "import tensorflow\n", - "print(tensorflow.__version__)\n", - "print(\"Tensorflow enabled.\")\n", - "\n", - "# ------- Variable specific to Stardist -------\n", - "from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available\n", - "from stardist.models import Config3D, StarDist3D, StarDistData3D\n", - "from stardist import relabel_image_stardist3D, Rays_GoldenSpiral, calculate_extents\n", - "from stardist.matching import matching_dataset\n", - "from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n", - "from csbdeep.io import save_tiff_imagej_compatible\n", - "import numpy as np\n", - "np.random.seed(42)\n", - "lbl_cmap = random_label_cmap()\n", - "from __future__ import print_function, unicode_literals, absolute_import, division\n", - "import cv2\n", - "%matplotlib inline\n", - "%config InlineBackend.figure_format = 'retina'\n", - "\n", - "# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - "import urllib\n", - "import os, random\n", - "import shutil \n", - "import zipfile\n", - "from tifffile import imread, imsave\n", - "import time\n", - "import sys\n", - "import wget\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import csv\n", - "from glob import glob\n", - "from scipy import signal\n", - "from scipy 
import ndimage\n", - "from skimage import io\n", - "from sklearn.linear_model import LinearRegression\n", - "from skimage.util import img_as_uint\n", - "import matplotlib as mpl\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from astropy.visualization import simple_norm\n", - "from skimage import img_as_float32\n", - "from skimage.util import img_as_ubyte\n", - "from tqdm import tqdm \n", - "from fpdf import FPDF, HTMLMixin\n", - "from datetime import datetime\n", - "from pip._internal.operations.freeze import freeze\n", - "import subprocess\n", - "from astropy.visualization import simple_norm\n", - "\n", - "\n", - "# For sliders and dropdown menu and progress bar\n", - "from ipywidgets import interact\n", - "import ipywidgets as widgets\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "\n", - "#Disable some of the tensorflow warnings\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "print(\"Libraries installed\")\n", - "\n", - "\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "if Notebook_version == list(Latest_notebook_version.columns):\n", - " print(\"This notebook is up-to-date.\")\n", - "\n", - "if not Notebook_version == list(Latest_notebook_version.columns):\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'StarDist 3D'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and methods:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras','csbdeep']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = 
cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(int(len(X)))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(int(len(X)))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented.'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<table width=40% style="margin-left:0px;">
  <tr><th align="left">Parameter</th><th align="left">Value</th></tr>
  <tr><td>number_of_epochs</td><td>{0}</td></tr>
  <tr><td>patch_size</td><td>{1}</td></tr>
  <tr><td>batch_size</td><td>{2}</td></tr>
  <tr><td>number_of_steps</td><td>{3}</td></tr>
  <tr><td>percentage_validation</td><td>{4}</td></tr>
  <tr><td>n_rays</td><td>{5}</td></tr>
  <tr><td>initial_learning_rate</td><td>{6}</td></tr>
  </table>
\n", - " \"\"\".format(number_of_epochs,str(patch_height)+'x'+str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,n_rays,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_StarDist3D.png').shape\n", - " pdf.image('/content/TrainingDataExample_StarDist3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- StarDist 3D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " ref_3 = '- StarDist 3D: Weigert, Martin, et al. \"Star-convex polyhedra for 3d object detection and segmentation in microscopy.\" The IEEE Winter Conference on Applications of Computer Vision. 2020.'\n", - " pdf.multi_cell(190, 5, txt = ref_3, align='L')\n", - " # if Use_Data_augmentation:\n", - " # ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " # pdf.multi_cell(190, 5, txt = ref_4, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", - "\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'Stardist 3D'\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n", - " if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n", - " pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n", - " pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+'/Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " PvGT_IoU = header[1]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \"\"\".format(image,PvGT_IoU)\n", - " html = html+header\n", - " for row in metrics:\n", - " image = row[0]\n", - " PvGT_IoU = row[1]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(PvGT_IoU),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
</table>
\"\"\"\n", - " \n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = ' - Weigert, Martin, et al. \"Star-convex polyhedra for 3d object detection and segmentation in microscopy.\" The IEEE Winter Conference on Applications of Computer Vision. 2020.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "\n", - "\n", - "def random_fliprot(img, mask, axis=None): \n", - " if axis is None:\n", - " axis = tuple(range(mask.ndim))\n", - " axis = tuple(axis)\n", - " \n", - " assert img.ndim>=mask.ndim\n", - " perm = tuple(np.random.permutation(axis))\n", - " transpose_axis = np.arange(mask.ndim)\n", - " for a, p in zip(axis, perm):\n", - " transpose_axis[a] = p\n", - " transpose_axis = tuple(transpose_axis)\n", - " img = img.transpose(transpose_axis + tuple(range(mask.ndim, img.ndim))) \n", - " mask = mask.transpose(transpose_axis) \n", - " for ax in axis: \n", - " if np.random.rand() > 0.5:\n", - " img = np.flip(img, axis=ax)\n", - " mask = np.flip(mask, axis=ax)\n", - " return img, mask \n", - "\n", - "def random_intensity_change(img):\n", - " img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)\n", - " return img\n", - "\n", - "def augmenter(x, y):\n", - " \"\"\"Augmentation of a single input/label image pair.\n", - " x is an input image\n", - " y is the corresponding ground-truth label image\n", - " \"\"\"\n", - " # Note that we only use fliprots along axis=(1,2), i.e. the yx axis \n", - " # as 3D microscopy acquisitions are usually not axially symmetric\n", - " x, y = random_fliprot(x, y, axis=(1,2))\n", - " x = random_intensity_change(x)\n", - " return x, y\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HLYcZR9gMv42" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FQ_QxtSWQ7CL" - }, - "source": [ - "## **3.1. Setting main training parameters**\n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AuESFimvMv43" - }, - "source": [ - " **Paths for training, predictions and results**\n", - "\n", - "**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. 
To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "**Training parameters**\n", - "\n", - "**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 200 epochs, but a full training should run for more. Evaluate the performance after training (see 5.). **Default value: 200**\n", - "\n", - "**Advanced parameters - experienced users only**\n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1** \n", - "\n", - "**`number_of_steps`:** Define the number of training steps by epoch. By default (or if set to 0) this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**. This value is multiplied by 6 when augmentation is enabled.\n", - "\n", - "**`patch_size`:** and **`patch_height`:** Input the size of the patches use to train StarDist 3D (length of a side). The value should be smaller or equal to the dimensions of the image. Make patch size and patch_height as large as possible and divisible by 8 and 4, respectively. **Default value: dimension of the training images**\n", - "\n", - "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n", - "\n", - "**`n_rays`:** Set number of rays (corners) used for StarDist (for instance a cube has 8 corners). **Default value: 96** \n", - "\n", - "**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0003**\n", - "\n", - "**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ewpNJ_I0Mv47", - "cellView": "form" - }, - "source": [ - "\n", - "\n", - "#@markdown ###Path to training images: \n", - "Training_source = \"\" #@param {type:\"string\"}\n", - "training_images = Training_source\n", - "\n", - "\n", - "Training_target = \"\" #@param {type:\"string\"}\n", - "mask_images = Training_target \n", - "\n", - "\n", - "#@markdown ###Name of the model and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "trained_model = model_path \n", - "\n", - "#@markdown ### Other parameters for training:\n", - "number_of_epochs = 200#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please input:\n", - "\n", - "#GPU_limit = 90 #@param {type:\"number\"}\n", - "batch_size = 2#@param {type:\"number\"}\n", - "number_of_steps = 0#@param {type:\"number\"}\n", - "patch_size = 96#@param {type:\"number\"} # pixels in\n", - "patch_height = 48#@param {type:\"number\"}\n", - "percentage_validation = 10#@param {type:\"number\"}\n", - "n_rays = 96 #@param {type:\"number\"}\n", - "initial_learning_rate = 0.0003 #@param {type:\"number\"}\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 2 # default from original author's notebook\n", - " n_rays = 96\n", - " percentage_validation = 10\n", - " initial_learning_rate = 0.0003\n", - "\n", - " patch_size = 96 # default from original author's notebook\n", - " patch_height = 48 # default from original author's notebook\n", - "\n", - "\n", - "percentage = percentage_validation/100\n", - "\n", - "#here we check that no model with the same name already exist, if so print a warning\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n", - " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n", - " \n", - "\n", - "random_choice=random.choice(os.listdir(Training_source))\n", - "x = imread(Training_source+\"/\"+random_choice)\n", - "\n", - "# Here we check that the input images are stacks\n", - "if len(x.shape) == 3:\n", - " print(\"Image dimensions (z,y,x)\",x.shape)\n", - "\n", - "if not len(x.shape) == 3:\n", - " print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n", - "\n", - "\n", - "#Find image Z dimension and select the mid-plane\n", - "Image_Z = x.shape[0]\n", - "mid_plane = int(Image_Z / 2)+1\n", - "\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[1]\n", - "Image_X = x.shape[2]\n", - "\n", - "# If default parameters, patch size is the same as image size\n", - "# if (Use_Default_Advanced_Parameters): \n", - "# patch_size = min(Image_Y, Image_X) \n", - "# patch_height = Image_Z\n", - "\n", - "#Hyperparameters failsafes\n", - "\n", - "# Here we check that patch_size is smaller than the smallest xy dimension of the image \n", - "\n", - "if patch_size > min(Image_Y, Image_X):\n", - " patch_size = min(Image_Y, Image_X)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "# Here we check that patch_size is divisible by 8\n", - "if not patch_size % 8 == 0:\n", - " patch_size = ((int(patch_size / 8)-1) * 8)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "# Here we check that patch_height is smaller than the z dimension of the image \n", - "\n", - "if patch_height > Image_Z :\n", - " patch_height = Image_Z\n", - " print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n", - "\n", - "# Here we check that patch_height is divisible by 4\n", - "if not patch_height % 4 == 0:\n", - " patch_height = ((int(patch_height / 4)-1) * 4)\n", - " if patch_height == 0:\n", - " patch_height = 4\n", - " print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n", - "\n", - "# Here we disable pre-trained model by default (in case the next cell is not ran)\n", - "Use_pretrained_model = False\n", - "\n", - "# Here we disable data augmentation by default (in case the cell is not ran)\n", - "\n", - "Use_Data_augmentation = False\n", - "\n", - "print(\"Parameters initiated.\")\n", - "\n", - "\n", - "os.chdir(Training_target)\n", - "y = imread(Training_target+\"/\"+random_choice)\n", - "\n", - "#Here we use a simple normalisation strategy to visualise the image\n", - "from astropy.visualization import simple_norm\n", - "norm = simple_norm(x, percent = 99)\n", - "\n", - "mid_plane = int(Image_Z / 2)+1\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n", - "plt.axis('off')\n", - "plt.title('Training source (single Z plane)');\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n", - "plt.axis('off')\n", - "plt.title('Training target (single Z plane)');\n", - "plt.savefig('/content/TrainingDataExample_StarDist3D.png',bbox_inches='tight',pad_inches=0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xyQZKby8yFME" - }, - "source": [ - "## **3.2. Data augmentation**\n", - "---\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_jCy7xOx2g3" - }, - "source": [ - "Data augmentation can improve training progress by amplifying differences in the dataset. 
This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n", - "\n", - " **However, data augmentation is not a magic solution and may also introduce issues. Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n", - "\n", - "Data augmentation is performed here by flipping, rotating and modifying the intensity of the images.\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "2ArWgdghVlQw" - }, - "source": [ - "#Data augmentation\n", - "\n", - "Use_Data_augmentation = False #@param {type:\"boolean\"}\n", - "\n", - "if Use_Data_augmentation:\n", - " augmenter = augmenter\n", - " print(\"Data augmentation enabled. Let's flip!\")\n", - "else:\n", - " augmenter = None\n", - " print(\"Data augmentation disabled.\")\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3L9zSGtORKYI" - }, - "source": [ - "\n", - "## **3.3. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
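Editor's note: as a complement to the description above, here is a minimal sketch of how the stored learning rate can be recovered from a ZeroCostDL4Mic `training_evaluation.csv` (the same idea the following cell implements). The file path is a placeholder.

```python
import pandas as pd

# Sketch only: the CSV written during training has 'loss', 'val_loss' and
# 'learning rate' columns. The path below is a placeholder.
df = pd.read_csv("my_model/Quality Control/training_evaluation.csv")

last_lr = df["learning rate"].iloc[-1]                      # Weights_choice == "last"
best_lr = df.loc[df["val_loss"].idxmin(), "learning rate"]  # Weights_choice == "best"
print(last_lr, best_lr)
```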
" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "9vC2n-HeLdiJ", - "cellView": "form" - }, - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "\n", - "pretrained_model_choice = \"Demo_3D_Model_from_Stardist_3D_paper\" #@param [\"Model_from_file\", \"Demo_3D_Model_from_Stardist_3D_paper\"]\n", - "\n", - "Weights_choice = \"best\" #@param [\"last\", \"best\"]\n", - "\n", - "\n", - "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - "# --------------------- Load the model from the choosen path ------------------------\n", - " if pretrained_model_choice == \"Model_from_file\":\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "\n", - "# --------------------- Download the Demo 3D model provided in the Stardist 3D github ------------------------\n", - "\n", - " if pretrained_model_choice == \"Demo_3D_Model_from_Stardist_3D_paper\":\n", - " pretrained_model_name = \"Demo_3D\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " print(\"Downloading the Demo 3D model from the Stardist_3D paper\")\n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " wget.download(\"https://raw.githubusercontent.com/mpicbg-csbd/stardist/master/models/examples/3D_demo/config.json\", pretrained_model_path)\n", - " wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/3D_demo/thresholds.json\", pretrained_model_path)\n", - " wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_best.h5?raw=true\", pretrained_model_path)\n", - " wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "# --------------------- Add additional pre-trained models here ------------------------\n", - "\n", - "\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n", - " if not os.path.exists(h5_file_path):\n", - " print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist'+W)\n", - " Use_pretrained_model = False\n", - "\n", - " \n", - "# If the model path contains a pretrain model, we load the training rate, \n", - " if os.path.exists(h5_file_path):\n", - "#Here we check if the learning rate can be loaded from the quality control folder\n", - " if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - "\n", - " with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " \n", - " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"pretrained network learning rate found\")\n", - " 
#find the last learning rate\n", - " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n", - " #Find the learning rate corresponding to the lowest validation loss\n", - " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n", - " #print(min_val_loss)\n", - " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n", - "\n", - " if Weights_choice == \"last\":\n", - " print('Last learning rate: '+str(lastLearningRate))\n", - "\n", - " if Weights_choice == \"best\":\n", - " print('Learning rate of best validation loss: '+str(bestLearningRate))\n", - "\n", - " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n", - "\n", - "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n", - " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - "\n", - "\n", - "# Display info about the pretrained model to be loaded (or not)\n", - "if Use_pretrained_model:\n", - " print(bcolors.WARNING+'Weights found in:')\n", - " print(h5_file_path)\n", - " print(bcolors.WARNING+'will be loaded prior to training.')\n", - "\n", - "else:\n", - " print(bcolors.WARNING+'No pretrained network will be used.')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MCGklf1vZf2M" - }, - "source": [ - "#**4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1KYOuygETJkT" - }, - "source": [ - "## **4.1. Prepare the training data and model for training**\n", - "---\n", - "Here, we use the information from 3. to build the model and convert the training data into a suitable format for training." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lIUAOJ_LMv5E", - "cellView": "form" - }, - "source": [ - "#@markdown ##Create the model and dataset objects\n", - "\n", - "# --------------------- Here we delete the model folder if it already exist ------------------------\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: Model folder already exists and has been removed !!\" + W)\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "\n", - "\n", - "import warnings\n", - "warnings.simplefilter(\"ignore\")\n", - "\n", - "# --------------------- Here we load the augmented data or the raw data ------------------------\n", - "\n", - "# if Use_Data_augmentation:\n", - "# Training_source_dir = Training_source_augmented\n", - "# Training_target_dir = Training_target_augmented\n", - "\n", - "# if not Use_Data_augmentation:\n", - "# Training_source_dir = Training_source\n", - "# Training_target_dir = Training_target\n", - "# --------------------- ------------------------------------------------\n", - "\n", - "Training_source_dir = Training_source\n", - "Training_target_dir = Training_target\n", - "training_images_tiff=Training_source_dir+\"/*.tif\"\n", - "mask_images_tiff=Training_target_dir+\"/*.tif\"\n", - "\n", - "\n", - "# this funtion imports training images and masks and sorts them suitable for the network\n", - "X = sorted(glob(training_images_tiff)) \n", - "Y = sorted(glob(mask_images_tiff)) \n", - "\n", - "# assert -funtion check that X and Y really have images. If not this cell raises an error\n", - "assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n", - "\n", - "# Here we map the training dataset (images and masks).\n", - "X = list(map(imread,X))\n", - "Y = list(map(imread,Y))\n", - "\n", - "n_channel = 1 if X[0].ndim == 3 else X[0].shape[-1]\n", - "\n", - "\n", - "\n", - "#Normalize images and fill small label holes.\n", - "axis_norm = (0,1,2) # normalize channels independently\n", - "# axis_norm = (0,1,2,3) # normalize channels jointly\n", - "if n_channel > 1:\n", - " print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))\n", - " sys.stdout.flush()\n", - "\n", - "X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n", - "Y = [fill_label_holes(y) for y in tqdm(Y)]\n", - "\n", - "#Here we split the your training dataset into training images (90 %) and validation images (10 %). 
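Editor's note: the cell above normalises every stack to its 1-99.8 percentile range (via `csbdeep.utils.normalize`) before splitting off the validation set. Below is a rough numpy-only illustration of that normalisation on a fake stack; the real csbdeep implementation adds dtype handling and optional clipping.

```python
import numpy as np

# Toy percentile normalisation, approximating normalize(x, 1, 99.8) from csbdeep.
x = np.random.poisson(lam=20, size=(8, 64, 64)).astype(np.float32)  # fake 3D stack
lo, hi = np.percentile(x, 1), np.percentile(x, 99.8)
x_norm = (x - lo) / (hi - lo + 1e-20)
print(float(x_norm.min()), float(x_norm.max()))  # close to 0 and 1
```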
\n", - "\n", - "assert len(X) > 1, \"not enough training data\"\n", - "rng = np.random.RandomState(42)\n", - "ind = rng.permutation(len(X))\n", - "n_val = max(1, int(round(percentage * len(ind))))\n", - "ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n", - "X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n", - "X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n", - "print('number of images: %3d' % len(X))\n", - "print('- training: %3d' % len(X_trn))\n", - "print('- validation: %3d' % len(X_val))\n", - "\n", - "\n", - "extents = calculate_extents(Y)\n", - "anisotropy = tuple(np.max(extents) / extents)\n", - "print('empirical anisotropy of labeled objects = %s' % str(anisotropy))\n", - "\n", - "\n", - "# Use OpenCL-based computations for data generator during training (requires 'gputools')\n", - "use_gpu = False and gputools_available()\n", - "\n", - "\n", - "#Here we ensure that our network has a minimal number of steps\n", - "if (Use_Default_Advanced_Parameters) or (number_of_steps == 0):\n", - " number_of_steps = (Image_X//patch_size)*(Image_Y//patch_size)*(Image_Z//patch_height) * int(len(X)/batch_size)+1\n", - " if (Use_Data_augmentation):\n", - " number_of_steps = number_of_steps*6\n", - "\n", - "print(\"Number of steps: \"+str(number_of_steps))\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "#Here we ensure that the learning rate set correctly when using pre-trained models\n", - "if Use_pretrained_model:\n", - " if Weights_choice == \"last\":\n", - " initial_learning_rate = lastLearningRate\n", - "\n", - " if Weights_choice == \"best\": \n", - " initial_learning_rate = bestLearningRate\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "# Predict on subsampled grid for increased efficiency and larger field of view\n", - "grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)\n", - "\n", - "# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data\n", - "rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)\n", - "\n", - "conf = Config3D (\n", - " rays = rays,\n", - " grid = grid,\n", - " anisotropy = anisotropy,\n", - " use_gpu = use_gpu,\n", - " n_channel_in = n_channel,\n", - " train_learning_rate = initial_learning_rate,\n", - " train_patch_size = (patch_height, patch_size, patch_size),\n", - " train_batch_size = batch_size,\n", - ")\n", - "print(conf)\n", - "vars(conf)\n", - "\n", - "\n", - "# --------------------- This is currently disabled as it give an error ------------------------\n", - "#here we limit GPU to 80%\n", - "if use_gpu:\n", - " from csbdeep.utils.tf import limit_gpu_memory\n", - " # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations\n", - " limit_gpu_memory(0.8)\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "\n", - "# Here we create a model according to section 5.3.\n", - "model = StarDist3D(conf, name=model_name, basedir=trained_model)\n", - "\n", - "# --------------------- Using pretrained model ------------------------\n", - "# Load the pretrained weights \n", - "if Use_pretrained_model:\n", - " model.load_weights(h5_file_path)\n", - "# --------------------- ---------------------- ------------------------\n", - "\n", - "\n", - "#Here we check the FOV of the network.\n", - "median_size = calculate_extents(Y, np.median)\n", - "fov = np.array(model._axes_tile_overlap('ZYX'))\n", - "if any(median_size > 
fov):\n", - " print(\"WARNING: median object size larger than field of view of the neural network.\")\n", - "\n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0Dfn8ZsEMv5d" - }, - "source": [ - "## **4.2. Start Training**\n", - "---\n", - "\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder." - ] - }, - { - "cell_type": "code", - "metadata": { - "scrolled": true, - "id": "iwNmp1PUzRDQ", - "cellView": "form" - }, - "source": [ - "import time\n", - "start = time.time()\n", - "\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "#@markdown ##Start training\n", - "\n", - "# augmenter = None\n", - "\n", - "# def augmenter(X_batch, Y_batch):\n", - "# \"\"\"Augmentation for data batch.\n", - "# X_batch is a list of input images (length at most batch_size)\n", - "# Y_batch is the corresponding list of ground-truth label images\n", - "# \"\"\"\n", - "# # ...\n", - "# return X_batch, Y_batch\n", - "\n", - "# Training the model. \n", - "# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n", - "history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n", - " epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n", - "None;\n", - "\n", - "print(\"Training done\")\n", - "\n", - "\n", - "# convert the history.history dict to a pandas DataFrame: \n", - "lossData = pd.DataFrame(history.history) \n", - "\n", - "if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n", - " shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n", - "\n", - "# The training evaluation.csv is saved (overwrites the Files if needed). 
\n", - "lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n", - "with open(lossDataCSVpath, 'w') as f:\n", - " writer = csv.writer(f)\n", - " writer.writerow(['loss','val_loss', 'learning rate'])\n", - " for i in range(len(history.history['loss'])):\n", - " writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n", - "\n", - "\n", - "print(\"Network optimization in progress\")\n", - "\n", - "#Here we optimize the network.\n", - "model.optimize_thresholds(X_val, Y_val)\n", - "print(\"Done\")\n", - "\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "#Create a pdf document with training summary\n", - "pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_0Hynw3-xHp1" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "eAJzMwPA6tlH", - "cellView": "form" - }, - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "QC_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "QC_model_name = os.path.basename(QC_model_folder)\n", - "QC_model_path = os.path.dirname(QC_model_folder)\n", - "\n", - "\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", - "\n", - "full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n", - "if os.path.exists(full_QC_model_path):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dhJROwlAMv5o" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vMzSP50kMv5p", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n", - "import csv\n", - "from matplotlib import pyplot as plt\n", - "\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "\n", - "with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[0]))\n", - " vallossDataFromCSV.append(float(row[1]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (log scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n", - "plt.show()\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X5_92nL2xdP6" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! 
The result for one of the image will also be displayed.\n", - "\n", - "The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n", - "\n", - " The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "w90MdriMxhjD", - "cellView": "form" - }, - "source": [ - "#@markdown ##Give the paths to an image to test the performance of the model with.\n", - "\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_QC_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "#Here we allow the user to choose the number of tile to be used when predicting the images\n", - "#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n", - "\n", - "Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n", - "#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n", - "n_tiles_Z = 1#@param {type:\"number\"}\n", - "n_tiles_Y = 1#@param {type:\"number\"}\n", - "n_tiles_X = 1#@param {type:\"number\"}\n", - "\n", - "if (Automatic_number_of_tiles): \n", - " n_tilesZYX = None\n", - "\n", - "if not (Automatic_number_of_tiles):\n", - " n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n", - "\n", - "\n", - "#Create a quality control Folder and check if the folder already exist\n", - "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n", - " os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n", - "\n", - "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n", - " shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "\n", - "# Generate predictions from the Source_QC_folder and save them in the QC folder\n", - "\n", - "Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n", - "\n", - "\n", - "np.random.seed(16)\n", - "lbl_cmap = random_label_cmap()\n", - "Z = sorted(glob(Source_QC_folder_tif))\n", - "Z = list(map(imread,Z))\n", - "n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n", - "axis_norm = (0,1) # normalize channels independently\n", - "\n", - "print('Number of test dataset found in the folder: '+str(len(Z)))\n", - "\n", - " \n", - " # axis_norm = (0,1,2) # normalize channels jointly\n", - "if n_channel > 1:\n", - " print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n", - "\n", - "model = StarDist3D(None, name=QC_model_name, basedir=QC_model_path)\n", - "\n", - "names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n", - 
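Editor's note: the QC cell in this section binarises the predicted and ground-truth masks and then computes the Intersection over Union as in this toy example (the two 3x3 arrays are invented for illustration).

```python
import numpy as np

# Toy IoU: overlap of two binarised masks divided by their union.
gt   = np.array([[0, 255, 255],
                 [0, 255,   0],
                 [0,   0,   0]])
pred = np.array([[0, 255,   0],
                 [0, 255, 255],
                 [0,   0,   0]])

intersection = np.logical_and(gt, pred)
union = np.logical_or(gt, pred)
iou = intersection.sum() / union.sum()
print(round(iou, 3))  # 2 shared pixels / 4 pixels in the union -> 0.5
```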
"\n", - " \n", - "# modify the names to suitable form: path_images/image_numberX.tif\n", - " \n", - "lenght_of_Z = len(Z)\n", - " \n", - "for i in range(lenght_of_Z):\n", - " img = normalize(Z[i], 1,99.8, axis=axis_norm)\n", - " labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n", - " os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - " imsave(names[i], labels, polygons)\n", - "\n", - "\n", - "# Here we start testing the differences between GT and predicted masks\n", - "\n", - "\n", - "with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - " writer.writerow([\"image\",\"Prediction v. GT Intersection over Union\"]) \n", - "\n", - " # Initialise the lists \n", - " filename_list = []\n", - " IoU_score_list = []\n", - "\n", - " for n in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n", - " print('Running QC on: '+n)\n", - " \n", - " test_input = io.imread(os.path.join(Source_QC_folder,n))\n", - " test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n", - " test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n", - "\n", - "#Convert pixel values to 0 or 255\n", - " test_prediction_0_to_255 = test_prediction\n", - " test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n", - "\n", - "#Convert pixel values to 0 or 255\n", - " test_ground_truth_0_to_255 = test_ground_truth_image\n", - " test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n", - "\n", - "# Intersection over Union metric\n", - "\n", - " intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n", - " union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - " writer.writerow([n, str(iou_score)])\n", - "\n", - " print(\"IoU: \"+str(round(iou_score,3)))\n", - "\n", - " filename_list.append(n)\n", - " IoU_score_list.append(iou_score)\n", - "\n", - "\n", - "\n", - "# Table with metrics as dataframe output\n", - "pdResults = pd.DataFrame(index = filename_list)\n", - "pdResults[\"IoU\"] = IoU_score_list\n", - "\n", - "# Display results\n", - "pdResults.head()\n", - "\n", - "\n", - "# ------------- For display ------------\n", - "print('--------------------------------------------------------------')\n", - "@interact\n", - "def show_QC_results(file=os.listdir(Source_QC_folder)):\n", - "\n", - " f=plt.figure(figsize=(32,8))\n", - "\n", - " test_input = io.imread(os.path.join(Source_QC_folder, file))\n", - " test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\", file))\n", - " test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file))\n", - "\n", - " norm = simple_norm(test_input, percent = 99)\n", - " Image_Z = test_input.shape[0]\n", - " mid_plane = int(Image_Z / 2)+1\n", - "\n", - " #Convert pixel values to 0 or 255\n", - " test_prediction_0_to_255 = test_prediction\n", - " test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n", - "\n", - " #Convert pixel values to 0 or 255\n", - " test_ground_truth_0_to_255 = test_ground_truth_image\n", - " test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n", - "\n", - " #Input\n", - " plt.subplot(1,4,1)\n", - " plt.axis('off')\n", - " plt.imshow(test_input[mid_plane], aspect='equal', norm=norm, 
cmap='magma', interpolation='nearest')\n", - " plt.title('Input')\n", - "\n", - " #Ground-truth\n", - " plt.subplot(1,4,2)\n", - " plt.axis('off')\n", - " plt.imshow(test_ground_truth_0_to_255[mid_plane], aspect='equal', cmap='Greens')\n", - " plt.title('Ground Truth')\n", - "\n", - " #Prediction\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " plt.imshow(test_prediction_0_to_255[mid_plane], aspect='equal', cmap='Purples')\n", - " plt.title('Prediction')\n", - "\n", - " #Overlay\n", - " plt.subplot(1,4,4)\n", - " plt.axis('off')\n", - " plt.imshow(test_ground_truth_0_to_255[mid_plane], cmap='Greens')\n", - " plt.imshow(test_prediction_0_to_255[mid_plane], alpha=0.5, cmap='Purples')\n", - " plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(pdResults.loc[file][\"IoU\"],3)))\n", - " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - " \n", - "\n", - "# Make a pdf summary of the QC results\n", - "qc_pdf_export()\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-tJeeJjLnRkP" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d8wuQGjoq6eN" - }, - "source": [ - "## **6.1. Generate prediction(s) from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n", - "\n", - "**`Data_folder`:** This folder should contains the images that you want to predict using the network that you trained.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output ROI.\n", - "\n", - "**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "y2TD5p7MZrEb", - "cellView": "form" - }, - "source": [ - "from PIL import Image\n", - "\n", - "#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n", - "\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "#test_dataset = Data_folder\n", - "\n", - "Results_folder = \"\" #@param {type:\"string\"}\n", - "#results = results_folder\n", - "\n", - "\n", - "# model name and path\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = False #@param {type:\"boolean\"}\n", - "\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "#Here we allow the user to choose the number of tile to be used when predicting the images\n", - "#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n", - "\n", - "Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n", - "#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. Progressively increases these numbers until the OOM error disappear.\n", - "n_tiles_Z = 1#@param {type:\"number\"}\n", - "n_tiles_Y = 1#@param {type:\"number\"}\n", - "n_tiles_X = 1#@param {type:\"number\"}\n", - "\n", - "if (Automatic_number_of_tiles): \n", - " n_tilesZYX = None\n", - "\n", - "if not (Automatic_number_of_tiles):\n", - " n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n", - "\n", - "\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "\n", - "#single images\n", - "#testDATA = test_dataset\n", - "Dataset = Data_folder+\"/*.tif\"\n", - "\n", - "\n", - "np.random.seed(16)\n", - "lbl_cmap = random_label_cmap()\n", - "X = sorted(glob(Dataset))\n", - "X = list(map(imread,X))\n", - "n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n", - "axis_norm = (0,1) # normalize channels independently\n", - " \n", - "# axis_norm = (0,1,2) # normalize channels jointly\n", - "if n_channel > 1:\n", - " print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n", - "model = StarDist3D(None, name=Prediction_model_name, basedir=Prediction_model_path)\n", - " \n", - "#Sorting and mapping original test dataset\n", - "X = sorted(glob(Dataset))\n", - "X = list(map(imread,X))\n", - "names = [os.path.basename(f) for f in sorted(glob(Dataset))]\n", - "\n", - "# modify the names to suitable form: path_images/image_numberX.tif\n", - "FILEnames=[]\n", - "for m in names:\n", - " m=Results_folder+'/'+m\n", - " FILEnames.append(m)\n", - "\n", - " # Predictions folder\n", - "lenght_of_X = len(X)\n", - "for i in range(lenght_of_X):\n", - " img = normalize(X[i], 1,99.8, axis=axis_norm)\n", - " labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n", - " \n", - "# Save the predicted mask in the result folder\n", - " os.chdir(Results_folder)\n", - " imsave(FILEnames[i], labels, polygons)\n", - "\n", - "\n", - "print(\"The mid-plane image is displayed below.\")\n", - "# ------------- For display ------------\n", - "print('--------------------------------------------------------------')\n", - "@interact\n", - "def show_QC_results(file=os.listdir(Data_folder)):\n", - " plt.figure(figsize=(13,10))\n", - "\n", - " img = imread(os.path.join(Data_folder, file))\n", - " img = normalize(img, 1,99.8, axis=axis_norm)\n", - " labels = imread(os.path.join(Results_folder, file))\n", - " z = max(0, img.shape[0] // 2 - 5)\n", - "\n", - " plt.subplot(121)\n", - " plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n", - " plt.title('Raw image (XY slice)')\n", - " plt.axis('off')\n", - " plt.subplot(122)\n", - " plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n", - " plt.imshow(labels[z], cmap=lbl_cmap, alpha=0.5)\n", - " plt.title('Image and predicted labels (XY slice)')\n", - " plt.axis('off');\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.2. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." 
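Editor's note: section 6.2 above gives the download instruction only in prose; one possible way to archive and download a results folder directly from a Colab session is sketched below. This assumes the notebook is running in Google Colab, and the folder path is a placeholder for your own `Results_folder`.

```python
import shutil
from google.colab import files  # only available inside Google Colab

# Zip the results folder (placeholder path) and trigger a browser download.
archive = shutil.make_archive('/content/StarDist3D_results', 'zip',
                              '/content/gdrive/My Drive/Results_folder')
files.download(archive)
```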
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UvSlTaH14s3t" - }, - "source": [ - "#**Thank you for using StarDist 3D!**" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"StarDist_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1610975750230},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **StarDist (3D)**\n","---\n","\n","**StarDist 3D** is a deep-learning method that can be used to segment cell nuclei from 3D bioimages and was first published by [Weigert *et al.* in 2019 on arXiv](https://arxiv.org/abs/1908.03636), extending to 3D the 2D appraoch from [Schmidt *et al.* in 2018](https://arxiv.org/abs/1806.03535). It uses a shape representation based on star-convex polygons for nuclei in an image to predict the presence and the shape of these nuclei. This StarDist 3D network is based on an adapted ResNet network architecture.\n","\n"," **This particular notebook enables nuclei segmentation of 3D dataset. If you are interested in 2D dataset, you should use the StarDist 2D notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the paper:\n","\n","**Cell Detection with Star-convex Polygons** from Schmidt *et al.*, International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. 
(https://arxiv.org/abs/1806.03535)\n","\n","and the 3D extension of the approach:\n","\n","**Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy** from Weigert *et al.* published on arXiv in 2019 (https://arxiv.org/abs/1908.03636)\n","\n","**The Original code** is freely available in GitHub:\n","/~https://github.com/mpicbg-csbd/stardist\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n"," For StarDist to train, **it needs to have access to a paired training dataset made of images of nuclei and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. 
The quality control assessment can be done directly in this notebook.\n","\n","The data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called \"Training - Images\" (Training_source) and \"Training - Masks\" (Training_target).\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Please note that you currently can **only use .tif files!**\n","\n","You can also provide a folder that contains the data that you wish to analyse with the trained network once all training has been performed.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Images of nuclei (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - Masks (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Images of nuclei\n"," - img_1.tif, img_2.tif\n"," - **Masks** \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. Install StarDist and dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n16OlQru8MHl"},"source":["## **1.1 Install StarDist**\n","---\n"]},{"cell_type":"code","metadata":{"id":"vvtlJacqSmOi","cellView":"form"},"source":["#@markdown ##Install StarDist and dependencies\n","\n","!pip install tifffile # contains tools to operate tiff-files\n","!pip install csbdeep # contains tools for restoration of fluorescence microcopy images (Content-aware Image Restoration, CARE). It uses Keras and Tensorflow.\n","!pip install -e git://github.com/stardist/stardist.git@0.6.2#egg=stardist\n","!pip install gputools\n","#!pip install edt\n","!pip install wget\n","!pip install fpdf\n","\n","#Force session restart\n","exit(0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_IG7IqJtSuFY"},"source":["## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
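As a quick sanity check after the runtime restart described in section 1.2, a minimal cell along the following lines (a sketch for illustration, not part of the original install cell) confirms that the packages installed in section 1.1 can be imported before the full set of dependencies is loaded in section 1.3:

```python
# Post-restart sanity check: confirm the packages installed in section 1.1 import cleanly.
import importlib

for package in ("tensorflow", "csbdeep", "stardist", "tifffile", "gputools"):
    try:
        module = importlib.import_module(package)
        print(package, getattr(module, "__version__", "version unknown"))
    except ImportError as error:
        print(package, "could not be imported:", error)
```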
\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"40rqHN2PTJlu"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["#@markdown ##Load key dependencies\n","\n","Notebook_version = '1.13'\n","Network = 'StarDist 3D'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","# %tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","print(\"Tensorflow enabled.\")\n","\n","# ------- Variable specific to Stardist -------\n","from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available\n","from stardist.models import Config3D, StarDist3D, StarDistData3D\n","from stardist import relabel_image_stardist3D, Rays_GoldenSpiral, calculate_extents\n","from stardist.matching import matching_dataset\n","from csbdeep.utils import Path, normalize, download_and_extract_zip_file, plot_history # for loss plot\n","from csbdeep.io import save_tiff_imagej_compatible\n","import numpy as np\n","np.random.seed(42)\n","lbl_cmap = random_label_cmap()\n","from __future__ import print_function, unicode_literals, absolute_import, division\n","import cv2\n","%matplotlib inline\n","%config InlineBackend.figure_format = 'retina'\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import 
img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from astropy.visualization import simple_norm\n","\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print(\"Libraries installed\")\n","\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," \n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for 
'+str(number_of_epochs)+' epochs on '+str(int(len(X)))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(int(len(X)))+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_height)+','+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+conf.train_dist_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented.'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table>
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>number_of_steps</td><td>{3}</td></tr>
<tr><td>percentage_validation</td><td>{4}</td></tr>
<tr><td>n_rays</td><td>{5}</td></tr>
<tr><td>initial_learning_rate</td><td>{6}</td></tr>
</table>
\n"," \"\"\".format(number_of_epochs,str(patch_height)+'x'+str(patch_size)+'x'+str(patch_size),batch_size,number_of_steps,percentage_validation,n_rays,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_StarDist3D.png').shape\n"," pdf.image('/content/TrainingDataExample_StarDist3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- StarDist 3D: Schmidt, Uwe, et al. \"Cell detection with star-convex polygons.\" International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," ref_3 = '- StarDist 3D: Weigert, Martin, et al. \"Star-convex polyhedra for 3d object detection and segmentation in microscopy.\" The IEEE Winter Conference on Applications of Computer Vision. 2020.'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_4, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Stardist 3D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/lossCurvePlots.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/Quality_Control for '+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," PvGT_IoU = header[1]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,PvGT_IoU)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," PvGT_IoU = row[1]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(PvGT_IoU),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th></tr>
<tr><td>{0}</td><td>{1}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = ' - Weigert, Martin, et al. \"Star-convex polyhedra for 3d object detection and segmentation in microscopy.\" The IEEE Winter Conference on Applications of Computer Vision. 2020.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n","\n","def random_fliprot(img, mask, axis=None): \n"," if axis is None:\n"," axis = tuple(range(mask.ndim))\n"," axis = tuple(axis)\n"," \n"," assert img.ndim>=mask.ndim\n"," perm = tuple(np.random.permutation(axis))\n"," transpose_axis = np.arange(mask.ndim)\n"," for a, p in zip(axis, perm):\n"," transpose_axis[a] = p\n"," transpose_axis = tuple(transpose_axis)\n"," img = img.transpose(transpose_axis + tuple(range(mask.ndim, img.ndim))) \n"," mask = mask.transpose(transpose_axis) \n"," for ax in axis: \n"," if np.random.rand() > 0.5:\n"," img = np.flip(img, axis=ax)\n"," mask = np.flip(mask, axis=ax)\n"," return img, mask \n","\n","def random_intensity_change(img):\n"," img = img*np.random.uniform(0.6,2) + np.random.uniform(-0.2,0.2)\n"," return img\n","\n","def augmenter(x, y):\n"," \"\"\"Augmentation of a single input/label image pair.\n"," x is an input image\n"," y is the corresponding ground-truth label image\n"," \"\"\"\n"," # Note that we only use fliprots along axis=(1,2), i.e. the yx axis \n"," # as 3D microscopy acquisitions are usually not axially symmetric\n"," x, y = random_fliprot(x, y, axis=(1,2))\n"," x = random_intensity_change(x)\n"," return x, y\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","# %tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. 
To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"7ncSG54e9LU4"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
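Once your Drive is mounted under /content/gdrive, it can help to list one of your training folders to confirm that the path copied from the Files tab is valid before filling in section 3.1. The folder used below is only a placeholder for illustration:

```python
import os

# Placeholder path for illustration only - replace with your own Training_source folder.
example_folder = "/content/gdrive/MyDrive/Experiment_A/Training_source"

if os.path.isdir(example_folder):
    tif_files = sorted(f for f in os.listdir(example_folder) if f.lower().endswith(".tif"))
    print("Found", len(tif_files), ".tif files, e.g.", tif_files[:3])
else:
    print("Folder not found - check the path copied from the Files tab.")
```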
"]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source (images of nuclei) and Training_target (masks) training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a 200 epochs, but a full training should run for more. Evaluate the performance after training (see 5.). **Default value: 200**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1** \n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default (or if set to 0) this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**. This value is multiplied by 6 when augmentation is enabled.\n","\n","**`patch_size`:** and **`patch_height`:** Input the size of the patches use to train StarDist 3D (length of a side). The value should be smaller or equal to the dimensions of the image. Make patch size and patch_height as large as possible and divisible by 8 and 4, respectively. **Default value: dimension of the training images**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** \n","\n","**`n_rays`:** Set number of rays (corners) used for StarDist (for instance a cube has 8 corners). **Default value: 96** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0003**\n","\n","**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size and patch_height values until the OOM error disappear.**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","\n","#@markdown ###Path to training images: \n","Training_source = \"\" #@param {type:\"string\"}\n","training_images = Training_source\n","\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","mask_images = Training_target \n","\n","\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","\n","model_path = \"\" #@param {type:\"string\"}\n","trained_model = model_path \n","\n","#@markdown ### Other parameters for training:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","#GPU_limit = 90 #@param {type:\"number\"}\n","batch_size = 2#@param {type:\"number\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","patch_size = 96#@param {type:\"number\"} # pixels in\n","patch_height = 48#@param {type:\"number\"}\n","percentage_validation = 10#@param {type:\"number\"}\n","n_rays = 96 #@param {type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 2 # default from original author's notebook\n"," n_rays = 96\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n","\n"," patch_size = 96 # default from original author's notebook\n"," patch_height = 48 # default from original author's notebook\n","\n","\n","percentage = percentage_validation/100\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","random_choice=random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","# Here we check that the input images are stacks\n","if len(x.shape) == 3:\n"," print(\"Image dimensions (z,y,x)\",x.shape)\n","\n","if not len(x.shape) == 3:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[1]\n","Image_X = x.shape[2]\n","\n","# If default parameters, patch size is the same as image size\n","# if (Use_Default_Advanced_Parameters): \n","# patch_size = min(Image_Y, Image_X) \n","# patch_height = Image_Z\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_height is smaller than the z dimension of the image \n","\n","if patch_height > Image_Z :\n"," patch_height = Image_Z\n"," print (bcolors.WARNING + \" Your chosen patch_height is bigger than the z dimension of your image; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we check that patch_height is divisible by 4\n","if not patch_height % 4 == 0:\n"," patch_height = ((int(patch_height / 4)-1) * 4)\n"," if patch_height == 0:\n"," patch_height = 4\n"," print (bcolors.WARNING + \" Your chosen patch_height is not divisible by 4; therefore the patch_size chosen is now:\",patch_height)\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","print(\"Parameters initiated.\")\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","#Here we use a simple normalisation strategy to visualise the image\n","from astropy.visualization import simple_norm\n","norm = simple_norm(x, percent = 99)\n","\n","mid_plane = int(Image_Z / 2)+1\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], interpolation='nearest', norm=norm, cmap='magma')\n","plt.axis('off')\n","plt.title('Training source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], interpolation='nearest', cmap=lbl_cmap)\n","plt.axis('off')\n","plt.title('Training target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_StarDist3D.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n"," **However, data augmentation is not a magic solution and may also introduce issues. 
Therefore, we recommend that you train your network with and without augmentation, and use the QC section to validate that it improves overall performances.** \n","\n","Data augmentation is performed here by flipping, rotating and modifying the intensity of the images.\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"2ArWgdghVlQw"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," augmenter = augmenter\n"," print(\"Data augmentation enabled. Let's flip!\")\n","else:\n"," augmenter = None\n"," print(\"Data augmentation disabled.\")\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a StarDist model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Demo_3D_Model_from_Stardist_3D_paper\" #@param [\"Model_from_file\", \"Demo_3D_Model_from_Stardist_3D_paper\"]\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the Demo 3D model provided in the Stardist 3D github ------------------------\n","\n"," if pretrained_model_choice == \"Demo_3D_Model_from_Stardist_3D_paper\":\n"," pretrained_model_name = \"Demo_3D\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the Demo 3D model from the Stardist_3D paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"https://raw.githubusercontent.com/mpicbg-csbd/stardist/master/models/examples/3D_demo/config.json\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/raw/master/models/examples/3D_demo/thresholds.json\", pretrained_model_path)\n"," 
wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_best.h5?raw=true\", pretrained_model_path)\n"," wget.download(\"/~https://github.com/mpicbg-csbd/stardist/blob/master/models/examples/3D_demo/weights_last.h5?raw=true\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist'+W)\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print(bcolors.WARNING+'Weights found in:')\n"," print(h5_file_path)\n"," print(bcolors.WARNING+'will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. 
Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Create the model and dataset objects\n","\n","# --------------------- Here we delete the model folder if it already exist ------------------------\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\" + W)\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","import warnings\n","warnings.simplefilter(\"ignore\")\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","# if Use_Data_augmentation:\n","# Training_source_dir = Training_source_augmented\n","# Training_target_dir = Training_target_augmented\n","\n","# if not Use_Data_augmentation:\n","# Training_source_dir = Training_source\n","# Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","Training_source_dir = Training_source\n","Training_target_dir = Training_target\n","training_images_tiff=Training_source_dir+\"/*.tif\"\n","mask_images_tiff=Training_target_dir+\"/*.tif\"\n","\n","\n","# this funtion imports training images and masks and sorts them suitable for the network\n","X = sorted(glob(training_images_tiff)) \n","Y = sorted(glob(mask_images_tiff)) \n","\n","# assert -funtion check that X and Y really have images. If not this cell raises an error\n","assert all(Path(x).name==Path(y).name for x,y in zip(X,Y))\n","\n","# Here we map the training dataset (images and masks).\n","X = list(map(imread,X))\n","Y = list(map(imread,Y))\n","\n","n_channel = 1 if X[0].ndim == 3 else X[0].shape[-1]\n","\n","\n","\n","#Normalize images and fill small label holes.\n","axis_norm = (0,1,2) # normalize channels independently\n","# axis_norm = (0,1,2,3) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 3 in axis_norm else 'independently'))\n"," sys.stdout.flush()\n","\n","X = [normalize(x,1,99.8,axis=axis_norm) for x in tqdm(X)]\n","Y = [fill_label_holes(y) for y in tqdm(Y)]\n","\n","#Here we split the your training dataset into training images (90 %) and validation images (10 %). 
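# Illustration of the split below (comment only, no extra code to run): 'percentage' was
# set in section 3.1 as percentage_validation / 100. With the default of 10 % and, for
# example, 20 training stacks, n_val = max(1, round(0.1 * 20)) = 2, so 18 volumes are used
# for training and 2 for validation. The fixed RandomState(42) keeps this split
# reproducible between runs of the cell.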
\n","\n","assert len(X) > 1, \"not enough training data\"\n","rng = np.random.RandomState(42)\n","ind = rng.permutation(len(X))\n","n_val = max(1, int(round(percentage * len(ind))))\n","ind_train, ind_val = ind[:-n_val], ind[-n_val:]\n","X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]\n","X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train] \n","print('number of images: %3d' % len(X))\n","print('- training: %3d' % len(X_trn))\n","print('- validation: %3d' % len(X_val))\n","\n","\n","extents = calculate_extents(Y)\n","anisotropy = tuple(np.max(extents) / extents)\n","print('empirical anisotropy of labeled objects = %s' % str(anisotropy))\n","\n","\n","# Use OpenCL-based computations for data generator during training (requires 'gputools')\n","use_gpu = False and gputools_available()\n","\n","\n","#Here we ensure that our network has a minimal number of steps\n","if (Use_Default_Advanced_Parameters) or (number_of_steps == 0):\n"," number_of_steps = (Image_X//patch_size)*(Image_Y//patch_size)*(Image_Z//patch_height) * int(len(X)/batch_size)+1\n"," if (Use_Data_augmentation):\n"," number_of_steps = number_of_steps*6\n","\n","print(\"Number of steps: \"+str(number_of_steps))\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","# --------------------- ---------------------- ------------------------\n","\n","# Predict on subsampled grid for increased efficiency and larger field of view\n","grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)\n","\n","# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data\n","rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)\n","\n","conf = Config3D (\n"," rays = rays,\n"," grid = grid,\n"," anisotropy = anisotropy,\n"," use_gpu = use_gpu,\n"," n_channel_in = n_channel,\n"," train_learning_rate = initial_learning_rate,\n"," train_patch_size = (patch_height, patch_size, patch_size),\n"," train_batch_size = batch_size,\n",")\n","print(conf)\n","vars(conf)\n","\n","\n","# --------------------- This is currently disabled as it give an error ------------------------\n","#here we limit GPU to 80%\n","if use_gpu:\n"," from csbdeep.utils.tf import limit_gpu_memory\n"," # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations\n"," limit_gpu_memory(0.8)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Here we create a model according to section 5.3.\n","model = StarDist3D(conf, name=model_name, basedir=trained_model)\n","\n","# --------------------- Using pretrained model ------------------------\n","# Load the pretrained weights \n","if Use_pretrained_model:\n"," model.load_weights(h5_file_path)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","#Here we check the FOV of the network.\n","median_size = calculate_extents(Y, np.median)\n","fov = np.array(model._axes_tile_overlap('ZYX'))\n","if any(median_size > fov):\n"," print(\"WARNING: median object size larger than field of view of the neural network.\")\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = 
Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["import time\n","start = time.time()\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","#@markdown ##Start training\n","\n","# augmenter = None\n","\n","# def augmenter(X_batch, Y_batch):\n","# \"\"\"Augmentation for data batch.\n","# X_batch is a list of input images (length at most batch_size)\n","# Y_batch is the corresponding list of ground-truth label images\n","# \"\"\"\n","# # ...\n","# return X_batch, Y_batch\n","\n","# Training the model. \n","# 'input_epochs' and 'steps' refers to your input data in section 5.1 \n","history = model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter,\n"," epochs=number_of_epochs, steps_per_epoch=number_of_steps)\n","None;\n","\n","print(\"Training done\")\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","if os.path.exists(model_path+\"/\"+model_name+\"/Quality Control\"):\n"," shutil.rmtree(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","os.makedirs(model_path+\"/\"+model_name+\"/Quality Control\")\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = model_path+'/'+model_name+'/Quality Control/training_evaluation.csv'\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n","\n","\n","print(\"Network optimization in progress\")\n","\n","#Here we optimize the network.\n","model.optimize_thresholds(X_val, Y_val)\n","print(\"Done\")\n","\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
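One of the checks performed below (section 5.2) is the Intersection over Union (IoU) between predicted and ground-truth masks. As a standalone illustration of that metric, independent of the notebook data, consider two small binary masks:

```python
import numpy as np

# Toy example of the IoU score computed in section 5.2, using two 4x4 binary masks.
ground_truth = np.zeros((4, 4), dtype=bool)
prediction = np.zeros((4, 4), dtype=bool)
ground_truth[1:3, 1:3] = True   # 4 ground-truth pixels
prediction[1:3, 1:4] = True     # 6 predicted pixels, 4 of which overlap

intersection = np.logical_and(ground_truth, prediction).sum()  # 4
union = np.logical_or(ground_truth, prediction).sum()          # 6
print("IoU =", intersection / union)  # ~0.67; a score of 1.0 means perfect overlap
```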
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder ! The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n"," The results can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\"."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Give the paths to an image to test the performance of the model with.\n","\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. 
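# Illustration (comment only): for a 60 x 2048 x 2048 stack that gives an OOM error,
# setting n_tiles_Z = 1, n_tiles_Y = 4 and n_tiles_X = 4 below asks predict_instances to
# process sixteen tiles of roughly 60 x 512 x 512 voxels (plus some overlap) independently
# and to stitch the resulting labels back together.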
Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","#Create a quality control Folder and check if the folder already exist\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\") == False:\n"," os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","\n","# Generate predictions from the Source_QC_folder and save them in the QC folder\n","\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","n_channel = 1 if Z[0].ndim == 2 else Z[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n","\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n"," \n"," # axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","\n","model = StarDist3D(None, name=QC_model_name, basedir=QC_model_path)\n","\n","names = [os.path.basename(f) for f in sorted(glob(Source_QC_folder_tif))]\n","\n"," \n","# modify the names to suitable form: path_images/image_numberX.tif\n"," \n","lenght_of_Z = len(Z)\n"," \n","for i in range(lenght_of_Z):\n"," img = normalize(Z[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," os.chdir(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n"," imsave(names[i], labels, polygons)\n","\n","\n","# Here we start testing the differences between GT and predicted masks\n","\n","\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Quality_Control for \"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"image\",\"Prediction v. 
GT Intersection over Union\"]) \n","\n"," # Initialise the lists \n"," filename_list = []\n"," IoU_score_list = []\n","\n"," for n in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,n)):\n"," print('Running QC on: '+n)\n"," \n"," test_input = io.imread(os.path.join(Source_QC_folder,n))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",n))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, n))\n","\n","#Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n","#Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n","# Intersection over Union metric\n","\n"," intersection = np.logical_and(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," union = np.logical_or(test_ground_truth_0_to_255, test_prediction_0_to_255)\n"," iou_score = np.sum(intersection) / np.sum(union)\n"," writer.writerow([n, str(iou_score)])\n","\n"," print(\"IoU: \"+str(round(iou_score,3)))\n","\n"," filename_list.append(n)\n"," IoU_score_list.append(iou_score)\n","\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = filename_list)\n","pdResults[\"IoU\"] = IoU_score_list\n","\n","# Display results\n","pdResults.head()\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Source_QC_folder)):\n","\n"," f=plt.figure(figsize=(32,8))\n","\n"," test_input = io.imread(os.path.join(Source_QC_folder, file))\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\", file))\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file))\n","\n"," norm = simple_norm(test_input, percent = 99)\n"," Image_Z = test_input.shape[0]\n"," mid_plane = int(Image_Z / 2)+1\n","\n"," #Convert pixel values to 0 or 255\n"," test_prediction_0_to_255 = test_prediction\n"," test_prediction_0_to_255[test_prediction_0_to_255>0] = 255\n","\n"," #Convert pixel values to 0 or 255\n"," test_ground_truth_0_to_255 = test_ground_truth_image\n"," test_ground_truth_0_to_255[test_ground_truth_0_to_255>0] = 255\n","\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(test_input[mid_plane], aspect='equal', norm=norm, cmap='magma', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_0_to_255[mid_plane], aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(test_prediction_0_to_255[mid_plane], aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_0_to_255[mid_plane], cmap='Greens')\n"," plt.imshow(test_prediction_0_to_255[mid_plane], alpha=0.5, cmap='Purples')\n"," plt.title('Ground Truth and Prediction, Intersection over Union:'+str(round(pdResults.loc[file][\"IoU\"],3)))\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n"," \n","\n","# Make a pdf summary of the QC 
results\n","qc_pdf_export()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.3) can now be used to process images. If an older model needs to be used, please untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Prediction_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contains the images that you want to predict using the network that you trained.\n","\n","**`Result_folder`:** This folder will contain the predicted output ROI.\n","\n","**`Data_type`:** Please indicate if the images you want to predict are single images or stacks\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["from PIL import Image\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","#test_dataset = Data_folder\n","\n","Results_folder = \"\" #@param {type:\"string\"}\n","#results = results_folder\n","\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#Here we allow the user to choose the number of tile to be used when predicting the images\n","#@markdown #####To analyse large image, your images need to be divided into tiles. Each tile will then be processed independently and re-assembled to generate the final image. \"Automatic_number_of_tiles\" will search for and use the smallest number of tiles that can be used, at the expanse of your runtime. Alternatively, manually input the number of tiles in each dimension to be used to process your images. \n","\n","Automatic_number_of_tiles = False #@param {type:\"boolean\"}\n","#@markdown #####If you get an Out of memory (OOM) error when using the \"Automatic_number_of_tiles\" option, disable it and manually input the values to be used to process your images. 
Progressively increases these numbers until the OOM error disappear.\n","n_tiles_Z = 1#@param {type:\"number\"}\n","n_tiles_Y = 1#@param {type:\"number\"}\n","n_tiles_X = 1#@param {type:\"number\"}\n","\n","if (Automatic_number_of_tiles): \n"," n_tilesZYX = None\n","\n","if not (Automatic_number_of_tiles):\n"," n_tilesZYX = (n_tiles_Z, n_tiles_Y, n_tiles_X)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","#single images\n","#testDATA = test_dataset\n","Dataset = Data_folder+\"/*.tif\"\n","\n","\n","np.random.seed(16)\n","lbl_cmap = random_label_cmap()\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","n_channel = 1 if X[0].ndim == 2 else X[0].shape[-1]\n","axis_norm = (0,1) # normalize channels independently\n"," \n","# axis_norm = (0,1,2) # normalize channels jointly\n","if n_channel > 1:\n"," print(\"Normalizing image channels %s.\" % ('jointly' if axis_norm is None or 2 in axis_norm else 'independently'))\n","model = StarDist3D(None, name=Prediction_model_name, basedir=Prediction_model_path)\n"," \n","#Sorting and mapping original test dataset\n","X = sorted(glob(Dataset))\n","X = list(map(imread,X))\n","names = [os.path.basename(f) for f in sorted(glob(Dataset))]\n","\n","# modify the names to suitable form: path_images/image_numberX.tif\n","FILEnames=[]\n","for m in names:\n"," m=Results_folder+'/'+m\n"," FILEnames.append(m)\n","\n"," # Predictions folder\n","lenght_of_X = len(X)\n","for i in range(lenght_of_X):\n"," img = normalize(X[i], 1,99.8, axis=axis_norm)\n"," labels, polygons = model.predict_instances(img, n_tiles=n_tilesZYX)\n"," \n","# Save the predicted mask in the result folder\n"," os.chdir(Results_folder)\n"," imsave(FILEnames[i], labels, polygons)\n","\n","\n","print(\"The mid-plane image is displayed below.\")\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Data_folder)):\n"," plt.figure(figsize=(13,10))\n","\n"," img = imread(os.path.join(Data_folder, file))\n"," img = normalize(img, 1,99.8, axis=axis_norm)\n"," labels = imread(os.path.join(Results_folder, file))\n"," z = max(0, img.shape[0] // 2 - 5)\n","\n"," plt.subplot(121)\n"," plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n"," plt.title('Raw image (XY slice)')\n"," plt.axis('off')\n"," plt.subplot(122)\n"," plt.imshow((img if img.ndim==3 else img[...,:3])[z], clim=(0,1), cmap='gray')\n"," plt.imshow(labels[z], cmap=lbl_cmap, alpha=0.5)\n"," plt.title('Image and predicted labels (XY slice)')\n"," plt.axis('off');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) 
if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"4DSziojsE6PC"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","* StarDist is now downgraded to v 0.6.2 to ensure compatibility with previously trained models.\n","\n","* This version now includes an automatic restart that sets the h5py library to v2.10.\n","* Sections 1 and 2 are now swapped for better export of *requirements.txt*.\n","\n","* This version also now includes a built-in version check and the version log that you're reading now.\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using StarDist 3D!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/Template_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Template_ZeroCostDL4Mic.ipynb index 1927db57..d471abdd 100644 --- a/Colab_notebooks/Template_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/Template_ZeroCostDL4Mic.ipynb @@ -1,1739 +1 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "gfIn-nNNhdzh" - }, - "source": [ - " This is a template for a ZeroCostDL4Mic notebook and needs to be filled with appropriate model code and information.\n", - "\n", - " Thank you for contributing to the ZeroCostDL4Mic Project. Please use this notebook as a template for your implementation. When your notebook is completed, please upload it to your GitHub page and send us a link so we can reference your work.\n", - "\n", - " If possible, remember to provide separate training and test datasets (for quality control) containing source and target images with your finished notebooks. This is very useful so that ZeroCostDL4Mic users can test your notebook. " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Av1qDcfthk1a" - }, - "source": [ - "# **Name of the Network**\n", - "\n", - "---\n", - "\n", - " Description of the network and link to publication with author reference. 
[author et al, etc.](URL).\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is inspired from the *Zero-Cost Deep-Learning to Enhance Microscopy* project (ZeroCostDL4Mic) (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki) and was created by **Your name**\n", - "\n", - "This notebook is based on the following paper: \n", - "\n", - "**Original Title of the paper**, Journal, volume, pages, year and complete author list, [link to paper](URL)\n", - "\n", - "And source code found in: *provide github link or equivalent if applicable*\n", - "\n", - "Provide information on dataset availability and link for download if applicable.\n", - "\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TKktwSaWhq9e" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use ZeroCostDL4Mic notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_v_Jl2QZhvLh" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - " Give information on the required structure and dataype of the training dataset.\n", - "\n", - " Provide information on quality control dataset, such as:\n", - "\n", - "**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n", - "\n", - " **Additionally, the corresponding input and output files need to have the same name**.\n", - "\n", - " Please note that you currently can **only use .tif files!**\n", - "\n", - "\n", - "Here's a common data structure that can work:\n", - "* Experiment A\n", - " - **Training dataset**\n", - " - Low SNR images (Training_source)\n", - " - img_1.tif, img_2.tif, ...\n", - " - High SNR images (Training_target)\n", - " - img_1.tif, img_2.tif, ...\n", - " - **Quality control dataset**\n", - " - Low SNR images\n", - " - img_1.tif, img_2.tif\n", - " - High SNR images\n", - " - img_1.tif, img_2.tif\n", - " - **Data to be predicted**\n", - " - **Results**\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EPOJkyFYiA15" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "---\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8dvLrwF_iEXS" - }, - "source": [ - "\n", - "## **1.1. Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "8o_-wbDOiIHF" - }, - "outputs": [], - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "%tensorflow_version 1.x\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime settings are correct then Google did not allocate GPU to your session')\n", - " print('Expect slow performance. 
To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - "\n", - "from tensorflow.python.client import device_lib \n", - "device_lib.list_local_devices()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kEyJvvxSiN6L" - }, - "source": [ - "## **1.2. Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "WWVR1U5tiM9h" - }, - "outputs": [], - "source": [ - "#@markdown ##Run this cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "#mounts user's Google Drive to Google Colab.\n", - "\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NvJvtQQgiVDF" - }, - "source": [ - "# **2. Install Name of the network and dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "XMi71QrxiZbS" - }, - "outputs": [], - "source": [ - "#@markdown ##Install Network and dependencies\n", - "\n", - "#Libraries contains information of certain topics. \n", - "\n", - "# Place all imports below this code snippet\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "#Put the imported code and libraries here\n", - "\n", - "Notebook_version = ['1.12'] #Contact the ZeroCostDL4Mic team to find out about the version number\n", - "\n", - "!pip install fpdf\n", - "\n", - "# Below are templates for the function definitions for the export\n", - "# of pdf summaries for training and qc. 
You will need to adjust these functions\n", - "# with the variables and other parameters as necessary to make them\n", - "# work for your project\n", - "from datetime import datetime\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " # save FPDF() class into a \n", - " # variable pdf \n", - " #from datetime import datetime\n", - "\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = \"Your network's name\"\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and methods:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras','csbdeep']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n", - " if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n", - " aug_text = aug_text+'\\n- rotation'\n", - " if flip_left_right != 0 or flip_top_bottom != 0:\n", - " aug_text = aug_text+'\\n- flipping'\n", - " if random_zoom_magnification != 0:\n", - " aug_text = aug_text+'\\n- random zoom magnification'\n", - " if random_distortion != 0:\n", - " aug_text = aug_text+'\\n- random distortion'\n", - " if image_shear != 0:\n", - " aug_text = aug_text+'\\n- image shearing'\n", - " if skew_image != 0:\n", - " aug_text = aug_text+'\\n- image skewing'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ParameterValue
number_of_epochs{0}
patch_size{1}
number_of_patches{2}
batch_size{3}
number_of_steps{4}
percentage_validation{5}
initial_learning_rate{6}
\n", - " \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(\"/content/NetworkNameExampleData.png\").shape\n", - " pdf.image(\"/content/NetworkNameExampleData.png\", x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Your networks name: first author et al. \"Title of publication\" Journal, year'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " if augmentation:\n", - " ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " pdf.multi_cell(190, 5, txt = ref_3, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", - "\n", - "\n", - "#Make a pdf summary of the QC results\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = \"Your network's name\"\n", - " #model_name = os.path.basename(full_QC_model_path)\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n", - " if os.path.exists(full_QC_model_path+'Quality Control/lossCurvePlots.png'):\n", - " pdf.image(full_QC_model_path+'Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.', align='L')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n", - " pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+'Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " mSSIM_PvsGT = header[1]\n", - " mSSIM_SvsGT = header[2]\n", - " NRMSE_PvsGT = header[3]\n", - " NRMSE_SvsGT = header[4]\n", - " PSNR_PvsGT = header[5]\n", - " PSNR_SvsGT = header[6]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n", - " html = html+header\n", - " for row in metrics:\n", - " 
image = row[0]\n", - " mSSIM_PvsGT = row[1]\n", - " mSSIM_SvsGT = row[2]\n", - " NRMSE_PvsGT = row[3]\n", - " NRMSE_SvsGT = row[4]\n", - " PSNR_PvsGT = row[5]\n", - " PSNR_SvsGT = row[6]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}
{0}{1}{2}{3}{4}{5}{6}
\"\"\"\n", - "\n", - " pdf.write_html(html)\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Your networks name: first author et al. \"Title of publication\" Journal, year'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "print(\"Depencies installed and imported.\")\n", - "\n", - "# Exporting requirements.txt for local run \n", - "# -- the developers should leave this below all the other installations\n", - "!pip freeze > requirements.txt\n", - "\n", - "# Code snippet to shorten requirements file to essential packages \n", - "after = [str(m) for m in sys.modules]\n", - "\n", - "# Ensure this code snippet is placed before all other imports!\n", - "# import sys\n", - "# before = [str(m) for m in sys.modules]\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "df = pd.read_csv('requirements.txt', delimiter = \"\\n\")\n", - "mod_list = [m.split('.')[0] for m in after if not m in before]\n", - "req_list_temp = df.values.tolist()\n", - "req_list = [x[0] for x in req_list_temp]\n", - "\n", - "# If necessary, extend mod_name_list with packages where import name is different from package name for pip install\n", - "mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - "mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - "filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - "# Insert name of network below\n", - "file=open('NAME_OF_NETWORK_requirements_simple.txt','w')\n", - "for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - "file.close()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jKaeBnSuifZn" - }, - "source": [ - "# **3. Select your paths and parameters**\n", - "\n", - "---\n", - "\n", - "The code below allows the user to enter the paths to where the training data is and to define the training parameters.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "StTGluw2iidc" - }, - "source": [ - "## **3.1. Setting the main training parameters**\n", - "---\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GyRjBdClimfK" - }, - "source": [ - " **Paths for training, predictions and results**\n", - "\n", - " Fill the parameters here as needed and update the code. 
Note that the sections containing `Training_source`, `Training target`, `model_name` and `model_path` should appear in your notebook.\n", - "\n", - "**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "**Training parameters**\n", - "\n", - "**`number_of_epochs`:**Give estimates for training performance given a number of epochs and provide a default value. **Default value:**\n", - "\n", - "**`other_parameters`:**Give other parameters or default values **Default value:**\n", - "\n", - "**If additional parameter above affects the training of the notebook give a brief explanation and how problems can be mitigated** \n", - "\n", - "\n", - "**Advanced parameters - experienced users only**\n", - "\n", - "**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n", - "\n", - "**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. 
**Default value: 10** " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "i1sKnXrDieiR" - }, - "outputs": [], - "source": [ - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "\n", - "#@markdown ###Path to training images:\n", - "\n", - "Training_source = \"\" #@param {type:\"string\"}\n", - "\n", - "# Ground truth images\n", - "Training_target = \"\" #@param {type:\"string\"}\n", - "\n", - "# model name and path\n", - "#@markdown ###Name of the model and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "# other parameters for training.\n", - "#@markdown ###Training Parameters\n", - "#@markdown Number of epochs:\n", - "\n", - "number_of_epochs = 50#@param {type:\"number\"}\n", - "\n", - "#@markdown Other parameters, add as necessary\n", - "other_parameters = 80#@param {type:\"number\"} # in pixels\n", - "\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "\n", - "Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "\n", - "number_of_steps = 400#@param {type:\"number\"}\n", - "batch_size = 16#@param {type:\"number\"}\n", - "percentage_validation = 10 #@param {type:\"number\"}\n", - "\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 16\n", - " percentage_validation = 10\n", - "\n", - "#Here we define the percentage to use for validation\n", - "percentage = percentage_validation/100\n", - "\n", - "\n", - "#here we check that no model with the same name already exist, if so delete\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "\n", - "\n", - "# The shape of the images.\n", - "x = imread(InputFile)\n", - "y = imread(OutputFile)\n", - "\n", - "print('Loaded Input images (number, width, length) =', x.shape)\n", - "print('Loaded Output images (number, width, length) =', y.shape)\n", - "print(\"Parameters initiated.\")\n", - "\n", - "# This will display a randomly chosen dataset input and output\n", - "random_choice = random.choice(os.listdir(Training_source))\n", - "x = imread(Training_source+\"/\"+random_choice)\n", - "\n", - "\n", - "# Here we check that the input images contains the expected dimensions\n", - "if len(x.shape) == 2:\n", - " print(\"Image dimensions (y,x)\",x.shape)\n", - "\n", - "if not len(x.shape) == 2:\n", - " print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n", - "\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "#Hyperparameters failsafes\n", - "\n", - "# Here we check that patch_size is smaller than the smallest xy dimension of the image \n", - "\n", - "if patch_size > min(Image_Y, Image_X):\n", - " patch_size = min(Image_Y, Image_X)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "# Here we check that patch_size is divisible by 8\n", - "if not patch_size % 8 == 0:\n", - " patch_size = ((int(patch_size / 8)-1) * 8)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "\n", - "\n", - "os.chdir(Training_target)\n", - "y = imread(Training_target+\"/\"+random_choice)\n", - "\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x, interpolation='nearest')\n", - "plt.title('Training source')\n", - "plt.axis('off');\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y, interpolation='nearest')\n", - "plt.title('Training target')\n", - "plt.axis('off');\n", - "#We save the example data here to use it in the pdf export of the training\n", - "plt.savefig('/content/NetworkNameExampleData.png', bbox_inches='tight', pad_inches=0)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VLYZQA6GitQL" - }, - "source": [ - "## **3.2. Data augmentation**\n", - "---\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "M4GfK6-1iwbf" - }, - "source": [ - "Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n", - "\n", - "Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. This only works if the images are square in XY.\n", - "\n", - "Add any other information which is necessary to run augmentation with your notebook/data." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "EkBGtraZi3Ob" - }, - "outputs": [], - "source": [ - "#@markdown ###Add any further useful augmentations\n", - "Use_Data_augmentation = False #@param{type:\"boolean\"}\n", - "\n", - "#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n", - "\n", - "#@markdown **Rotate each image 3 times by 90 degrees.**\n", - "Rotation = True #@param{type:\"boolean\"}\n", - "\n", - "#@markdown **Flip each image once around the x axis of the stack.**\n", - "Flip = True #@param{type:\"boolean\"}\n", - "\n", - "\n", - "#@markdown **Would you like to save your augmented images?**\n", - "\n", - "Save_augmented_images = False #@param {type:\"boolean\"}\n", - "\n", - "Saving_path = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "if not Save_augmented_images:\n", - " Saving_path= \"/content\"\n", - "\n", - "\n", - "def rotation_aug(Source_path, Target_path, flip=False):\n", - " Source_images = os.listdir(Source_path)\n", - " Target_images = os.listdir(Target_path)\n", - " \n", - " for image in Source_images:\n", - " source_img = io.imread(os.path.join(Source_path,image))\n", - " target_img = io.imread(os.path.join(Target_path,image))\n", - " \n", - " # Source Rotation\n", - " source_img_90 = np.rot90(source_img,axes=(1,2))\n", - " source_img_180 = np.rot90(source_img_90,axes=(1,2))\n", - " source_img_270 = np.rot90(source_img_180,axes=(1,2))\n", - "\n", - " # Target Rotation\n", - " target_img_90 = np.rot90(target_img,axes=(1,2))\n", - " target_img_180 = np.rot90(target_img_90,axes=(1,2))\n", - " target_img_270 = np.rot90(target_img_180,axes=(1,2))\n", - "\n", - " # Add a flip to the rotation\n", - " \n", - " if flip == True:\n", - " source_img_lr = np.fliplr(source_img)\n", - " source_img_90_lr = np.fliplr(source_img_90)\n", - " source_img_180_lr = np.fliplr(source_img_180)\n", - " source_img_270_lr = np.fliplr(source_img_270)\n", - "\n", - " target_img_lr = np.fliplr(target_img)\n", - " target_img_90_lr = np.fliplr(target_img_90)\n", - " target_img_180_lr = np.fliplr(target_img_180)\n", - " target_img_270_lr = np.fliplr(target_img_270)\n", - "\n", - " #source_img_90_ud = np.flipud(source_img_90)\n", - " \n", - " # Save the augmented files\n", - " # Source images\n", - " io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n", - " # Target images\n", - " io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n", - "\n", - " if flip == True:\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n", - " 
io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n", - "\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n", - "\n", - "def flip(Source_path, Target_path):\n", - " Source_images = os.listdir(Source_path)\n", - " Target_images = os.listdir(Target_path) \n", - "\n", - " for image in Source_images:\n", - " source_img = io.imread(os.path.join(Source_path,image))\n", - " target_img = io.imread(os.path.join(Target_path,image))\n", - " \n", - " source_img_lr = np.fliplr(source_img)\n", - " target_img_lr = np.fliplr(target_img)\n", - "\n", - " io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n", - " io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n", - "\n", - " io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n", - " io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n", - "\n", - "\n", - "if Use_Data_augmentation:\n", - "\n", - " if os.path.exists(Saving_path+'/augmented_source'):\n", - " shutil.rmtree(Saving_path+'/augmented_source')\n", - " os.mkdir(Saving_path+'/augmented_source')\n", - "\n", - " if os.path.exists(Saving_path+'/augmented_target'):\n", - " shutil.rmtree(Saving_path+'/augmented_target') \n", - " os.mkdir(Saving_path+'/augmented_target')\n", - "\n", - " print(\"Data augmentation enabled\")\n", - " print(\"Data augmentation in progress....\")\n", - "\n", - " if Rotation == True:\n", - " rotation_aug(Training_source,Training_target,flip=Flip)\n", - " \n", - " elif Rotation == False and Flip == True:\n", - " flip(Training_source,Training_target)\n", - " print(\"Done\")\n", - "\n", - "\n", - "if not Use_Data_augmentation:\n", - " print(bcolors.WARNING+\"Data augmentation disabled\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-Y-47ZmFiyG_" - }, - "source": [ - "## **3.3. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a model of Your Network**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n", - "\n", - " In order to continue training from the point where the pret-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
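The cell below only locates the weights file (`h5_file_path`) and the matching learning rate; the template explicitly asks you to add the code that actually loads those weights into your network. For a Keras-based model this step often looks like the sketch below, where `build_model` is a placeholder for however your network is constructed and the other variables are the ones defined in the following cell.

import os

def apply_pretrained_weights(model):
    # Sketch: load pre-trained weights (if requested) and pick the matching learning rate.
    # Uses Use_pretrained_model, h5_file_path, Weights_choice, lastLearningRate,
    # bestLearningRate and initial_learning_rate from the cell below.
    if Use_pretrained_model and os.path.exists(h5_file_path):
        model.load_weights(h5_file_path)          # standard Keras call to restore weights
        return lastLearningRate if Weights_choice == "last" else bestLearningRate
    return initial_learning_rate

# Example use (hypothetical):
# model = build_model(...)                        # however your network is defined
# learning_rate = apply_pretrained_weights(model)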
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "jSb9luhrjHe-" - }, - "outputs": [], - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "\n", - "pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n", - "\n", - "Weights_choice = \"last\" #@param [\"last\", \"best\"]\n", - "\n", - "\n", - "#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - "# --------------------- Load the model from the choosen path ------------------------\n", - " if pretrained_model_choice == \"Model_from_file\":\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "\n", - "# --------------------- Download the a model provided in the XXX ------------------------\n", - "\n", - " if pretrained_model_choice == \"Model_name\":\n", - " pretrained_model_name = \"Model_name\"\n", - " pretrained_model_path = \"/content/\"+pretrained_model_name\n", - " print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n", - " if os.path.exists(pretrained_model_path):\n", - " shutil.rmtree(pretrained_model_path)\n", - " os.makedirs(pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path)\n", - " wget.download(\"\", pretrained_model_path) \n", - " wget.download(\"\", pretrained_model_path)\n", - " h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n", - "\n", - "# --------------------- Add additional pre-trained models here ------------------------\n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n", - " if not os.path.exists(h5_file_path):\n", - " print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n", - " Use_pretrained_model = False\n", - "\n", - " \n", - "# If the model path contains a pretrain model, we load the training rate, \n", - " if os.path.exists(h5_file_path):\n", - "#Here we check if the learning rate can be loaded from the quality control folder\n", - " if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - "\n", - " with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n", - " csvRead = pd.read_csv(csvfile, sep=',')\n", - " #print(csvRead)\n", - " \n", - " if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n", - " print(\"pretrained network learning rate found\")\n", - " #find the last learning rate\n", - " lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n", - " #Find the learning rate corresponding to the lowest validation loss\n", - " min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n", - " #print(min_val_loss)\n", - " bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n", - "\n", - " if Weights_choice == \"last\":\n", - " print('Last learning rate: '+str(lastLearningRate))\n", - "\n", - " if 
Weights_choice == \"best\":\n", - " print('Learning rate of best validation loss: '+str(bestLearningRate))\n", - "\n", - " if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n", - "\n", - "#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n", - " if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n", - " print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n", - " bestLearningRate = initial_learning_rate\n", - " lastLearningRate = initial_learning_rate\n", - "\n", - "\n", - "# Display info about the pretrained model to be loaded (or not)\n", - "if Use_pretrained_model:\n", - " print('Weights found in:')\n", - " print(h5_file_path)\n", - " print('will be loaded prior to training.')\n", - "\n", - "else:\n", - " print(bcolors.WARNING+'No pretrained nerwork will be used.')\n", - "\n", - "\n", - "#@markdown ### You will need to add or replace the code that loads any previously trained weights to the notebook here." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sjTtP2OmjMqM" - }, - "source": [ - "# **4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yQ9NgI6XjQIk" - }, - "source": [ - "## **4.1. Train the network**\n", - "---\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "SVUd0Lr0jUjy" - }, - "outputs": [], - "source": [ - "import time\n", - "import csv\n", - "\n", - "# Export the training parameters as pdf (before training, in case training fails) \n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n", - "\n", - "start = time.time()\n", - "\n", - "#@markdown ##Start training\n", - "\n", - "# Start Training\n", - "\n", - "#Insert the code necessary to initiate training of your model\n", - "\n", - "#Note that the notebook should load weights either from the model that is \n", - "#trained from scratch or if the pretrained weights are used (3.3.)\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "# Export the training parameters as pdf (after training)\n", - "pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1Tm3aimXjZ1B" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "QAXu1FR0jYZC" - }, - "outputs": [], - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained ?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the name of the model and path to model folder:\n", - "#@markdown #####During training, the model files are automatically saved inside a folder named after model_name in section 3. Provide the path to this folder below. \n", - "\n", - "QC_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "QC_model_name = os.path.basename(QC_model_folder)\n", - "QC_model_path = os.path.dirname(QC_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", - "\n", - "full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n", - "if os.path.exists(full_QC_model_path):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ULMuc37njkXM" - }, - "source": [ - "## **5.1. Inspection of the loss function**\n", - "---\n", - "\n", - "First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n", - "\n", - "**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n", - "\n", - "**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n", - "\n", - "During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n", - "\n", - "Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "1VCvEofKjjHN" - }, - "outputs": [], - "source": [ - "#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n", - "import csv\n", - "from matplotlib import pyplot as plt\n", - "\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "\n", - "with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[0]))\n", - " vallossDataFromCSV.append(float(row[1]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training loss and validation loss vs. epoch number (log scale)')\n", - "plt.ylabel('Loss')\n", - "plt.xlabel('Epoch number')\n", - "plt.legend()\n", - "plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n", - "plt.show()\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "smiWe2wcjwTc" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "\n", - " Update the code below to perform predictions on your quality control dataset. 
Use the metrics that are the most meaningful to assess the quality of the prediction.\n", - "\n", - "This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n", - "\n", - "**1. The SSIM (structural similarity) map** \n", - "\n", - "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n", - "\n", - "**mSSIM** is the SSIM value calculated across the entire window of both images.\n", - "\n", - "**The output below shows the SSIM maps with the mSSIM**\n", - "\n", - "**2. The RSE (Root Squared Error) map** \n", - "\n", - "This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n", - "\n", - "\n", - "**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n", - "\n", - "**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n", - "\n", - "**The output below shows the RSE maps with the NRMSE and PSNR values.**\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "Z179Zxgtj0PP" - }, - "outputs": [], - "source": [ - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_QC_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "# Create a quality control/Prediction Folder\n", - "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n", - " shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n", - "\n", - "# Insert code to activate the pretrained model if necessary. 
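The two placeholders in this quality-control cell ("activate the pretrained model" and "perform predictions on all datasets in the Source_QC folder") also depend on your network. A possible sketch, again for a generic Keras model; `build_model()` is a hypothetical stand-in for your own architecture code, and the added batch/channel axes may not match your data layout:

```python
import os
import numpy as np
from glob import glob
from skimage import io

# Re-create the network and load the weights of the model selected for QC.
qc_model = build_model()  # hypothetical: your own architecture definition
qc_model.load_weights(os.path.join(full_QC_model_path, "weights_best.h5"))

# Predict every .tif in Source_QC_folder and save it under the same file name,
# so the metric loop further down can pair source, target and prediction.
prediction_dir = os.path.join(QC_model_path, QC_model_name, "Quality Control", "Prediction")
for file_path in sorted(glob(os.path.join(Source_QC_folder, "*.tif"))):
    img = io.imread(file_path).astype(np.float32)
    pred = qc_model.predict(img[np.newaxis, ..., np.newaxis])  # add batch and channel axes
    io.imsave(os.path.join(prediction_dir, os.path.basename(file_path)),
              np.squeeze(pred).astype(np.float32))
```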
\n", - "\n", - "# List Tif images in Source_QC_folder\n", - "Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n", - "Z = sorted(glob(Source_QC_folder_tif))\n", - "Z = list(map(imread,Z))\n", - "print('Number of test dataset found in the folder: '+str(len(Z)))\n", - "\n", - "\n", - "# Insert code to perform predictions on all datasets in the Source_QC folder\n", - "\n", - "\n", - "def ssim(img1, img2):\n", - " return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n", - "\n", - "\n", - "def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " \"\"\"Percentile-based image normalization.\"\"\"\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - "def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - " if dtype is not None:\n", - " x = x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - "def norm_minmse(gt, x, normalize_gt=True):\n", - " \"\"\"This function is adapted from Martin Weigert\"\"\"\n", - "\n", - " \"\"\"\n", - " normalizes and affinely scales an image pair such that the MSE is minimized \n", - " \n", - " Parameters\n", - " ----------\n", - " gt: ndarray\n", - " the ground truth image \n", - " x: ndarray\n", - " the image that will be affinely scaled \n", - " normalize_gt: bool\n", - " set to True of gt image should be normalized (default)\n", - " Returns\n", - " -------\n", - " gt_scaled, x_scaled \n", - " \"\"\"\n", - " if normalize_gt:\n", - " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", - " x = x.astype(np.float32, copy=False) - np.mean(x)\n", - " #x = x - np.mean(x)\n", - " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", - " #gt = gt - np.mean(gt)\n", - " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", - " return gt, scale * x\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - "with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. 
GT PSNR\"]) \n", - "\n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - "\n", - " for i in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", - " print('Running QC on: '+i)\n", - " # -------------------------------- Target test data (Ground truth) --------------------------------\n", - " test_GT = io.imread(os.path.join(Target_QC_folder, i))\n", - "\n", - " # -------------------------------- Source test data --------------------------------\n", - " test_source = io.imread(os.path.join(Source_QC_folder,i))\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n", - " test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n", - "\n", - " # -------------------------------- Prediction --------------------------------\n", - " test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n", - " test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n", - "\n", - "\n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n", - "\n", - " #Save ssim_maps\n", - " img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n", - " img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n", - " \n", - " # Calculate the Root Squared Error (RSE) maps\n", - " img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", - " img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n", - "\n", - " # Save SE maps\n", - " img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n", - " img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n", - " io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n", - "\n", - "\n", - " # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n", - "\n", - " # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n", - " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", - " NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n", - " \n", - " # We can also measure the peak signal to noise ratio between the images\n", - " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", - " PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n", - "\n", - " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n", - "\n", - "\n", - "# All data is now processed saved\n", - "Test_FileList = os.listdir(Source_QC_folder) # this 
assumes, as it should, that both source and target are named the same\n", - "\n", - "plt.figure(figsize=(15,15))\n", - "# Currently only displays the last computed set, from memory\n", - "# Target (Ground-truth)\n", - "plt.subplot(3,3,1)\n", - "plt.axis('off')\n", - "img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n", - "plt.imshow(img_GT)\n", - "plt.title('Target',fontsize=15)\n", - "\n", - "# Source\n", - "plt.subplot(3,3,2)\n", - "plt.axis('off')\n", - "img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n", - "plt.imshow(img_Source)\n", - "plt.title('Source',fontsize=15)\n", - "\n", - "#Prediction\n", - "plt.subplot(3,3,3)\n", - "plt.axis('off')\n", - "img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n", - "plt.imshow(img_Prediction)\n", - "plt.title('Prediction',fontsize=15)\n", - "\n", - "#Setting up colours\n", - "cmap = plt.cm.CMRmap\n", - "\n", - "#SSIM between GT and Source\n", - "plt.subplot(3,3,5)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - "plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - "plt.title('Target vs. Source',fontsize=15)\n", - "plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", - "plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - "plt.subplot(3,3,6)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - "plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - "plt.title('Target vs. Prediction',fontsize=15)\n", - "plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#Root Squared Error between GT and Source\n", - "plt.subplot(3,3,8)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n", - "plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n", - "plt.title('Target vs. Source',fontsize=15)\n", - "plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n", - "#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n", - "plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#Root Squared Error between GT and Prediction\n", - "plt.subplot(3,3,9)\n", - "#plt.axis('off')\n", - "plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - "plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", - "plt.title('Target vs. Prediction',fontsize=15)\n", - "plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "\n", - "#Make a pdf summary of the QC results\n", - "\n", - "qc_pdf_export()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fB8QNLekkCyZ" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "B2DrAOANkIWu" - }, - "source": [ - "## **6.1. Generate prediction(s) from unseen dataset**\n", - "---\n", - "Fill the below code to perform predictions using your model.\n", - "\n", - "The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n", - "\n", - "**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output images." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "mELG8z-ykCKV" - }, - "outputs": [], - "source": [ - "#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n", - "\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Result_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "# model name and path\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, provide the name of the model and path to model folder:\n", - "#@markdown #####During training, the model files are automatically saved inside a folder named after model_name in section 3. 
Provide the path to this folder below.\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "\n", - "# Activate the (pre-)trained model\n", - "\n", - "\n", - "# Provide the code for performing predictions and saving them\n", - "\n", - "\n", - "print(\"Images saved into folder:\", Result_folder)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JnSk14AJkRtJ" - }, - "source": [ - "## **6.2. Inspect the predicted output**\n", - "---\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "hlkZUhj4kQ2Z" - }, - "outputs": [], - "source": [ - "# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n", - "\n", - "# This will display a randomly chosen dataset input and predicted output\n", - "random_choice = random.choice(os.listdir(Data_folder))\n", - "x = imread(Data_folder+\"/\"+random_choice)\n", - "\n", - "os.chdir(Result_folder)\n", - "y = imread(Result_folder+\"/\"+random_choice)\n", - "\n", - "plt.figure(figsize=(16,8))\n", - "\n", - "plt.subplot(1,2,1)\n", - "plt.axis('off')\n", - "plt.imshow(x, interpolation='nearest')\n", - "plt.title('Input')\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.axis('off')\n", - "plt.imshow(y, interpolation='nearest')\n", - "plt.title('Predicted output');\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gP7WDm6bkYkb" - }, - "source": [ - "## **6.3. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." 
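For the prediction cell in Section 6.1, the placeholders ("Activate the (pre-)trained model" and "Provide the code for performing predictions and saving them") could be filled along the same lines; `build_model()` is again a hypothetical stand-in and the axis handling depends on your data:

```python
import os
import numpy as np
from skimage import io

prediction_model = build_model()  # hypothetical: your own architecture definition
prediction_model.load_weights(os.path.join(full_Prediction_model_path, "weights_best.h5"))

for name in sorted(os.listdir(Data_folder)):
    if not name.lower().endswith(".tif"):
        continue
    img = io.imread(os.path.join(Data_folder, name)).astype(np.float32)
    pred = prediction_model.predict(img[np.newaxis, ..., np.newaxis])
    # Save as an ImageJ-compatible 32-bit TIFF under the original file name.
    io.imsave(os.path.join(Result_folder, name), np.squeeze(pred).astype(np.float32))
```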
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JbOn8U-VkerU" - }, - "source": [ - "\n", - "#**Thank you for using YOUR NETWORK!**" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "machine_shape": "hm", - "name": "Template_ZeroCostDL4Mic.ipynb", - "provenance": [ - { - "file_id": "1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU", - "timestamp": 1611141557911 - }, - { - "file_id": "1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX", - "timestamp": 1610543191319 - }, - { - "file_id": "1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY", - "timestamp": 1602522500580 - }, - { - "file_id": "1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz", - "timestamp": 1588762142860 - }, - { - "file_id": "10weAY0es-pEfHlACCaBCKK7PmgdoJqdh", - "timestamp": 1587728072051 - }, - { - "file_id": "10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB", - "timestamp": 1586789421439 - }, - { - "file_id": "1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6", - "timestamp": 1583244509550 - } - ], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Template_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611141557911},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"gfIn-nNNhdzh"},"source":[" This is a template for a ZeroCostDL4Mic notebook and needs to be filled with appropriate model code and information.\n","\n"," Thank you for contributing to the ZeroCostDL4Mic Project. Please use this notebook as a template for your implementation. When your notebook is completed, please upload it to your github page and send us a link so we can reference your work.\n","\n"," If possible, remember to provide separate training and test datasets (for quality control) containing source and target images with your finished notebooks. This is very useful so that ZeroCostDL4Mic users can test your notebook. "]},{"cell_type":"markdown","metadata":{"id":"Av1qDcfthk1a"},"source":["# **Name of the Network**\n","\n","---\n","\n"," Description of the network and link to publication with author reference. 
[author et al, etc.](URL).\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is inspired from the *Zero-Cost Deep-Learning to Enhance Microscopy* project (ZeroCostDL4Mic) (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki) and was created by **Your name**\n","\n","This notebook is based on the following paper: \n","\n","**Original Title of the paper**, Journal, volume, pages, year and complete author list, [link to paper](URL)\n","\n","And source code found in: *provide github link or equivalent if applicable*\n","\n","Provide information on dataset availability and link for download if applicable.\n","\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"TKktwSaWhq9e"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use ZeroCostDL4Mic notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"_v_Jl2QZhvLh"},"source":["#**0. 
Before getting started**\n","---\n"," Give information on the required structure and dataype of the training dataset.\n","\n"," Provide information on quality control dataset, such as:\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .tif files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Low SNR images (Training_source)\n"," - img_1.tif, img_2.tif, ...\n"," - High SNR images (Training_target)\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Low SNR images\n"," - img_1.tif, img_2.tif\n"," - High SNR images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"NvJvtQQgiVDF"},"source":["# **1. Install Name of the network and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"XMi71QrxiZbS","cellView":"form"},"source":["#@markdown ##Install Network and dependencies\n","\n","#Libraries contains information of certain topics. 
\n","\n","#Put the imported code and libraries here\n","\n","Notebook_version = ['1.12'] #Contact the ZeroCostDL4Mic team to find out about the version number\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","!pip install fpdf\n","\n","# Below are templates for the function definitions for the export\n","# of pdf summaries for training and qc. You will need to adjust these functions\n","# with the variables and other parameters as necessary to make them\n","# work for your project\n","from datetime import datetime\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," # save FPDF() class into a \n"," # variable pdf \n"," #from datetime import datetime\n","\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = \"Your network's name\"\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, 
shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n"," if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if flip_left_right != 0 or flip_top_bottom != 0:\n"," aug_text = aug_text+'\\n- flipping'\n"," if random_zoom_magnification != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if random_distortion != 0:\n"," aug_text = aug_text+'\\n- random distortion'\n"," if image_shear != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," if skew_image != 0:\n"," aug_text = aug_text+'\\n- image skewing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>number_of_patches</td><td>{2}</td></tr>
<tr><td>batch_size</td><td>{3}</td></tr>
<tr><td>number_of_steps</td><td>{4}</td></tr>
<tr><td>percentage_validation</td><td>{5}</td></tr>
<tr><td>initial_learning_rate</td><td>{6}</td></tr>
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(\"/content/NetworkNameExampleData.png\").shape\n"," pdf.image(\"/content/NetworkNameExampleData.png\", x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Your networks name: first author et al. \"Title of publication\" Journal, year'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," if augmentation:\n"," ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","#Make a pdf summary of the QC results\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = \"Your network's name\"\n"," #model_name = os.path.basename(full_QC_model_path)\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/13))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.', align='L')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," NRMSE_SvsGT = header[4]\n"," PSNR_PvsGT = header[5]\n"," PSNR_SvsGT = header[6]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT,NRMSE_PvsGT,NRMSE_SvsGT,PSNR_PvsGT,PSNR_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," NRMSE_SvsGT = row[4]\n"," PSNR_PvsGT = row[5]\n"," PSNR_SvsGT = row[6]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," 
\"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(NRMSE_SvsGT),3)),str(round(float(PSNR_PvsGT),3)),str(round(float(PSNR_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th><th>{5}</th><th>{6}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td><td>{5}</td><td>{6}</td></tr>
\"\"\"\n","\n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Your networks name: first author et al. \"Title of publication\" Journal, year'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","print(\"Depencies installed and imported.\")\n","\n","# Build requirements file for local run\n","# -- the developers should leave this below all the other installations\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EPOJkyFYiA15"},"source":["# **2. Initialise the Colab session**\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"8dvLrwF_iEXS"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"8o_-wbDOiIHF"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime settings are correct then Google did not allocate GPU to your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n","\n","from tensorflow.python.client import device_lib \n","device_lib.list_local_devices()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kEyJvvxSiN6L"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"WWVR1U5tiM9h"},"source":["#@markdown ##Run this cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. 
Your Google Drive folder should now be available here as \"drive\". \n","\n","#mounts user's Google Drive to Google Colab.\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jKaeBnSuifZn"},"source":["# **3. Select your paths and parameters**\n","\n","---\n","\n","The code below allows the user to enter the paths to where the training data is and to define the training parameters.\n"]},{"cell_type":"markdown","metadata":{"id":"StTGluw2iidc"},"source":["## **3.1. Setting the main training parameters**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"GyRjBdClimfK"},"source":[" **Paths for training, predictions and results**\n","\n"," Fill the parameters here as needed and update the code. Note that the sections containing `Training_source`, `Training target`, `model_name` and `model_path` should appear in your notebook.\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Give estimates for training performance given a number of epochs and provide a default value. **Default value:**\n","\n","**`other_parameters`:**Give other parameters or default values **Default value:**\n","\n","**If additional parameter above affects the training of the notebook give a brief explanation and how problems can be mitigated** \n","\n","\n","**Advanced parameters - experienced users only**\n","\n","**`number_of_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each patch is seen at least once per epoch. **Default value: Number of patch / batch_size**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. 
**Default value: 10** "]},{"cell_type":"code","metadata":{"cellView":"form","id":"i1sKnXrDieiR"},"source":["class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","\n","# Ground truth images\n","Training_target = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","\n","number_of_epochs = 50#@param {type:\"number\"}\n","\n","#@markdown Other parameters, add as necessary\n","other_parameters = 80#@param {type:\"number\"} # in pixels\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","\n","number_of_steps = 400#@param {type:\"number\"}\n","batch_size = 16#@param {type:\"number\"}\n","percentage_validation = 10 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 16\n"," percentage_validation = 10\n","\n","#Here we define the percentage to use for validation\n","percentage = percentage_validation/100\n","\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","\n","\n","# The shape of the images.\n","x = imread(InputFile)\n","y = imread(OutputFile)\n","\n","print('Loaded Input images (number, width, length) =', x.shape)\n","print('Loaded Output images (number, width, length) =', y.shape)\n","print(\"Parameters initiated.\")\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = imread(Training_source+\"/\"+random_choice)\n","\n","\n","# Here we check that the input images contains the expected dimensions\n","if len(x.shape) == 2:\n"," print(\"Image dimensions (y,x)\",x.shape)\n","\n","if not len(x.shape) == 2:\n"," print(bcolors.WARNING +\"Your images appear to have the wrong dimensions. 
Image dimension\",x.shape)\n","\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","#Hyperparameters failsafes\n","\n","# Here we check that patch_size is smaller than the smallest xy dimension of the image \n","\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 8\n","if not patch_size % 8 == 0:\n"," patch_size = ((int(patch_size / 8)-1) * 8)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 8; therefore the patch_size chosen is now:\",patch_size)\n","\n","\n","\n","os.chdir(Training_target)\n","y = imread(Training_target+\"/\"+random_choice)\n","\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","#We save the example data here to use it in the pdf export of the training\n","plt.savefig('/content/NetworkNameExampleData.png', bbox_inches='tight', pad_inches=0)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"VLYZQA6GitQL"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"M4GfK6-1iwbf"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating the patches in XY-Plane and flip them along X-Axis. 
This only works if the images are square in XY.\n","\n","Add any other information which is necessary to run augmentation with your notebook/data."]},{"cell_type":"code","metadata":{"cellView":"form","id":"EkBGtraZi3Ob"},"source":["#@markdown ###Add any further useful augmentations\n","Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," 
io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/augmented_source/'+image,source_img)\n"," io.imsave(Saving_path+'/augmented_source/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/augmented_target/'+image,target_img)\n"," io.imsave(Saving_path+'/augmented_target/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n","\n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," \n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," print(\"Done\")\n","\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-Y-47ZmFiyG_"},"source":["## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a model of Your Network**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pret-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
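The cell below only locates the `weights_last.h5` / `weights_best.h5` file; as noted at the end of that cell, the code that actually loads these weights into the network still has to be added by the notebook developer. A minimal sketch, assuming the network is a Keras model and that `h5_file_path` has already been resolved by the cell below (the 3D-RCAN training code may expose a different entry point, so adapt as needed):

```python
import os

def maybe_load_pretrained(model, h5_file_path, use_pretrained_model):
    """Minimal sketch: restore weights from a previous run before training starts."""
    if use_pretrained_model and os.path.exists(h5_file_path):
        model.load_weights(h5_file_path)   # Keras: loads layer weights only, not the optimizer state
        print('Pretrained weights loaded from ' + h5_file_path)
    return model
```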
"]},{"cell_type":"code","metadata":{"cellView":"form","id":"jSb9luhrjHe-"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the 2D_Demo_Model_from_Stardist_2D_paper\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".h5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: weights_last.h5 pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n"," \n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," 
lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(bcolors.WARNING+'No pretrained nerwork will be used.')\n","\n","\n","#@markdown ### You will need to add or replace the code that loads any previously trained weights to the notebook here."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sjTtP2OmjMqM"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"yQ9NgI6XjQIk"},"source":["## **4.1. Train the network**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"cellView":"form","id":"SVUd0Lr0jUjy"},"source":["import time\n","import csv\n","\n","# Export the training parameters as pdf (before training, in case training fails) \n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","start = time.time()\n","\n","#@markdown ##Start training\n","\n","# Start Training\n","\n","#Insert the code necessary to initiate training of your model\n","\n","#Note that the notebook should load weights either from the model that is \n","#trained from scratch or if the pretrained weights are used (3.3.)\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","# Export the training parameters as pdf (after training)\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1Tm3aimXjZ1B"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
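The training cell above deliberately leaves the actual training call empty. A minimal Keras-style sketch of what could go there, assuming a compiled `model` and `train_gen`/`val_gen` generators (all hypothetical names); the checkpoint file names follow the `weights_best.h5` / `weights_last.h5` convention used elsewhere in this template, and the CSV log is in the spirit of the `training_evaluation.csv` read back in section 5.1, although the real 3D-RCAN training loop may differ:

```python
import os
from keras.callbacks import ModelCheckpoint, CSVLogger

def run_training(model, train_gen, val_gen, model_path, model_name,
                 number_of_epochs, number_of_steps):
    """Minimal sketch of a training call with the checkpoints the template expects."""
    qc_dir = os.path.join(model_path, model_name, 'Quality Control')
    os.makedirs(qc_dir, exist_ok=True)
    callbacks = [
        # keep the weights with the lowest validation loss and the last epoch's weights
        ModelCheckpoint(os.path.join(model_path, model_name, 'weights_best.h5'),
                        save_best_only=True, save_weights_only=True),
        ModelCheckpoint(os.path.join(model_path, model_name, 'weights_last.h5'),
                        save_best_only=False, save_weights_only=True),
        # per-epoch loss / val_loss log, similar to the CSV plotted in section 5.1
        CSVLogger(os.path.join(qc_dir, 'training_evaluation.csv')),
    ]
    return model.fit_generator(train_gen, steps_per_epoch=number_of_steps,
                               epochs=number_of_epochs, validation_data=val_gen,
                               callbacks=callbacks)
```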
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"QAXu1FR0jYZC"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after model_name in section 3. Provide the path to this folder below. \n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ULMuc37njkXM"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"cellView":"form","id":"1VCvEofKjjHN"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(QC_model_path+'/'+QC_model_name+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/Quality Control/lossCurvePlots.png')\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"smiWe2wcjwTc"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n"," Update the code below to perform predictions on your quality control dataset. Use the metrics that are the most meaningful to assess the quality of the prediction.\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
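For reference, the standard definitions behind these two scores (the cell below applies its own normalisation first, so its exact numbers may differ slightly):

```python
import numpy as np

def psnr_db(gt, pred, data_range=1.0):
    """Peak signal-to-noise ratio in decibels: 10 * log10(peak**2 / MSE)."""
    mse = np.mean((gt - pred) ** 2)
    return 10 * np.log10((data_range ** 2) / mse)

def rmse(gt, pred):
    """Root-mean-squared error; 'normalised' when both images share a common intensity range."""
    return np.sqrt(np.mean((gt - pred) ** 2))
```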
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"Z179Zxgtj0PP"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\")\n","\n","# Insert code to activate the pretrained model if necessary. \n","\n","# List Tif images in Source_QC_folder\n","Source_QC_folder_tif = Source_QC_folder+\"/*.tif\"\n","Z = sorted(glob(Source_QC_folder_tif))\n","Z = list(map(imread,Z))\n","print('Number of test dataset found in the folder: '+str(len(Z)))\n","\n","\n","# Insert code to perform predictions on all datasets in the Source_QC folder\n","\n","\n","def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/QC_metrics_\"+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = 
csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\"]) \n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT = io.imread(os.path.join(Target_QC_folder, i))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = io.imread(os.path.join(Source_QC_folder,i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\",i))\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_32bit = np.float32(img_SSIM_GTvsPrediction)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsPrediction_'+i,img_SSIM_GTvsPrediction_32bit)\n"," img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/SSIM_GTvsSource_'+i,img_SSIM_GTvsSource_32bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_32bit = np.float32(img_RSE_GTvsPrediction)\n"," img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsPrediction_'+i,img_RSE_GTvsPrediction_32bit)\n"," io.imsave(QC_model_path+'/'+QC_model_name+'/Quality Control/RSE_GTvsSource_'+i,img_RSE_GTvsSource_32bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource)])\n","\n","\n","# All data is now processed 
saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same\n","\n","plt.figure(figsize=(15,15))\n","# Currently only displays the last computed set, from memory\n","# Target (Ground-truth)\n","plt.subplot(3,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_GT)\n","plt.title('Target',fontsize=15)\n","\n","# Source\n","plt.subplot(3,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source)\n","plt.title('Source',fontsize=15)\n","\n","#Prediction\n","plt.subplot(3,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction/\", Test_FileList[-1]))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n","plt.subplot(3,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n","plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n","plt.subplot(3,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n","plt.subplot(3,3,8)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n","plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Source',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. 
Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n","plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(3,3,9)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"fB8QNLekkCyZ"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"B2DrAOANkIWu"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","Fill the below code to perform predictions using your model.\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images."]},{"cell_type":"code","metadata":{"cellView":"form","id":"mELG8z-ykCKV"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model and path to model folder:\n","#@markdown #####During training, the model files are automatically saved inside a folder named after model_name in section 3. 
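As with training, the prediction cell in this section only contains placeholders for activating the model and running inference. A minimal sketch of a prediction loop, assuming a Keras `model` restored from the chosen model folder and single-channel TIFF stacks (the real 3D-RCAN inference script may differ):

```python
import os
import numpy as np
from skimage import io

def predict_folder(model, data_folder, result_folder):
    """Minimal sketch: run the trained model on every .tif in data_folder and save the output."""
    for name in sorted(os.listdir(data_folder)):
        if not name.lower().endswith('.tif'):
            continue
        img = io.imread(os.path.join(data_folder, name)).astype(np.float32)
        # add the batch and channel axes Keras expects, then strip them from the result
        pred = model.predict(img[np.newaxis, ..., np.newaxis])[0, ..., 0]
        io.imsave(os.path.join(result_folder, name), pred.astype(np.float32))
```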
Provide the path to this folder below.\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# Activate the (pre-)trained model\n","\n","\n","# Provide the code for performing predictions and saving them\n","\n","\n","print(\"Images saved into folder:\", Result_folder)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JnSk14AJkRtJ"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"hlkZUhj4kQ2Z"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+random_choice)\n","\n","plt.figure(figsize=(16,8))\n","\n","plt.subplot(1,2,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","\n","plt.subplot(1,2,2)\n","plt.axis('off')\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output');\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gP7WDm6bkYkb"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"owyvVA3ndrwA"},"source":["# **7. 
Version log**\n","---\n","**vXXX**: \n","\n","\n","* Indicate here the modifications made to your notebook as it evolves.\n","\n"]},{"cell_type":"markdown","metadata":{"id":"JbOn8U-VkerU"},"source":["\n","#**Thank you for using YOUR NETWORK!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb index 6c842bd3..e4765d01 100644 --- a/Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/U-Net_2D_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1USP7Bmd4UEdhp9cOBc_wlqDXnZjMHk_f","timestamp":1622215174280},{"file_id":"1EZG34jBKULVmO__Fmv7Lr76sVHIMxwJx","timestamp":1622041273450},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **U-Net (2D)**\n","---\n","\n","U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n","\n"," **This particular notebook enables image segmentation of 2D dataset. If you are interested in 3D dataset, you should use the 3D U-Net notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the papers: \n","\n","**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n","\n","and \n","\n","**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n","(https://www.nature.com/articles/s41592-018-0261-2)\n","And source code found in: /~https://github.com/zhixuhao/unet by *Zhixuhao*\n","\n","**Please also cite this original paper when using or developing this notebook.** "]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. 
Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For U-Net to train, **it needs to have access to a paired training dataset corresponding to images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif, ...\n"," - Training_target\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif\n"," - Training_target \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. 
To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install U-Net dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"UGWnGOFsf07b"},"source":["## **2.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"cellView":"form","id":"uc0haIa-fZiG"},"source":["#@markdown ##Play to install U-Net dependencies\n","\n","!pip install data\n","!pip install fpdf\n","!pip install h5py==2.10\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"I4O5zctbf4Gb"},"source":["## **2.2. Restart your runtime**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"F4jMunHMfq_c"},"source":["** Your Runtime has automatically restarted. This is normal.**\n","\n"]},{"cell_type":"markdown","metadata":{"id":"iiX3Ly-7gA5h"},"source":["## **2.3. 
Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12.1']\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Load key U-Net dependencies\n","\n","#As this notebokk depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)\n","#only the data library needs to be additionally installed.\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","# print(tensorflow.__version__)\n","# print(\"Tensorflow enabled.\")\n","\n","\n","# Keras imports\n","from keras import models\n","from keras.models import Model, load_model\n","from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n","from keras.optimizers import Adam\n","# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints\n","from keras.callbacks import ModelCheckpoint\n","from keras.callbacks import ReduceLROnPlateau\n","from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n","from keras import backend as keras\n","\n","# General import\n","from __future__ import print_function\n","import numpy as np\n","import pandas as pd\n","import os\n","import glob\n","from skimage import img_as_ubyte, io, transform\n","import matplotlib as mpl\n","from matplotlib import pyplot as plt\n","from matplotlib.pyplot import imread\n","from pathlib import Path\n","import shutil\n","import random\n","import time\n","import csv\n","import sys\n","from math import ceil\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","# Imports for QC\n","from PIL import Image\n","from scipy import signal\n","from scipy import ndimage\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets 
import interact\n","import ipywidgets as widgets\n","# from tqdm import tqdm\n","from tqdm.notebook import tqdm\n","\n","from sklearn.feature_extraction import image\n","from skimage import img_as_ubyte, io, transform\n","from skimage.util.shape import view_as_windows\n","\n","from datetime import datetime\n","\n","\n","# Suppressing some warnings\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","\n","\n","\n","def create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction):\n"," \"\"\"\n"," Function creates patches from the Training_source and Training_target images. \n"," The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n"," Saves all created patches in two new directories in the /content folder.\n","\n"," min_fraction is the minimum fraction of pixels that need to be foreground to be considered as a valid patch\n","\n"," Returns: - Two paths to where the patches are now saved\n"," \"\"\"\n"," DEBUG = False\n","\n"," Patch_source = os.path.join('/content','img_patches')\n"," Patch_target = os.path.join('/content','mask_patches')\n"," Patch_rejected = os.path.join('/content','rejected')\n"," \n","\n"," #Here we save the patches, in the /content directory as they will not usually be needed after training\n"," if os.path.exists(Patch_source):\n"," shutil.rmtree(Patch_source)\n"," if os.path.exists(Patch_target):\n"," shutil.rmtree(Patch_target)\n"," if os.path.exists(Patch_rejected):\n"," shutil.rmtree(Patch_rejected)\n","\n"," os.mkdir(Patch_source)\n"," os.mkdir(Patch_target)\n"," os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n"," \n"," patch_num = 0\n","\n"," for file in tqdm(os.listdir(Training_source)):\n","\n"," img = io.imread(os.path.join(Training_source, file))\n"," mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n","\n"," if DEBUG:\n"," print(file)\n"," print(img.dtype)\n","\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n"," patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n","\n"," patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)\n"," patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)\n","\n"," if DEBUG:\n"," print(all_patches_img.shape)\n"," print(all_patches_img.dtype)\n","\n"," for i in range(patches_img.shape[0]):\n"," img_save_path = os.path.join(Patch_source,'patch_'+str(patch_num)+'.tif')\n"," mask_save_path = os.path.join(Patch_target,'patch_'+str(patch_num)+'.tif')\n"," patch_num += 1\n","\n"," # if the mask conatins at least 2% of its total number pixels as mask, then go ahead and save the images\n"," pixel_threshold_array = sorted(patches_mask[i].flatten())\n"," if pixel_threshold_array[int(round(len(pixel_threshold_array)*(1-min_fraction)))]>0:\n"," io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(patches_img[i])))\n"," io.imsave(mask_save_path, convert2Mask(normalizeMinMax(patches_mask[i]),0))\n"," else:\n"," io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_image.tif', img_as_ubyte(normalizeMinMax(patches_img[i])))\n"," io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_mask.tif', convert2Mask(normalizeMinMax(patches_mask[i]),0))\n","\n"," return Patch_source, Patch_target\n","\n","\n","def 
estimatePatchSize(data_path, max_width = 512, max_height = 512):\n","\n"," files = os.listdir(data_path)\n"," \n"," # Get the size of the first image found in the folder and initialise the variables to that\n"," n = 0 \n"," while os.path.isdir(os.path.join(data_path, files[n])):\n"," n += 1\n"," (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n","\n"," # Screen the size of all dataset to find the minimum image size\n"," for file in files:\n"," if not os.path.isdir(os.path.join(data_path, file)):\n"," (height, width) = Image.open(os.path.join(data_path, file)).size\n"," if width < width_min:\n"," width_min = width\n"," if height < height_min:\n"," height_min = height\n"," \n"," # Find the power of patches that will fit within the smallest dataset\n"," width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n","\n"," # Clip values at maximum permissible values\n"," if width_min > max_width:\n"," width_min = max_width\n","\n"," if height_min > max_height:\n"," height_min = max_height\n"," \n"," return (width_min, height_min)\n","\n","def fittingPowerOfTwo(number):\n"," n = 0\n"," while 2**n <= number:\n"," n += 1 \n"," return 2**(n-1)\n","\n","\n","def getClassWeights(Training_target_path):\n","\n"," Mask_dir_list = os.listdir(Training_target_path)\n"," number_of_dataset = len(Mask_dir_list)\n","\n"," class_count = np.zeros(2, dtype=int)\n"," for i in tqdm(range(number_of_dataset)):\n"," mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n"," mask = normalizeMinMax(mask)\n"," class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n"," class_count[1] += mask.sum()\n","\n"," n_samples = class_count.sum()\n"," n_classes = 2\n","\n"," class_weights = n_samples / (n_classes * class_count)\n"," return class_weights\n","\n","def weighted_binary_crossentropy(class_weights):\n","\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n"," weight_vector = y_true * class_weights[1] + (1. - y_true) * class_weights[0]\n"," weighted_binary_crossentropy = weight_vector * binary_crossentropy\n","\n"," return keras.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","\n","def save_augment(datagen,orig_img,dir_augmented_data=\"/content/augment\"):\n"," \"\"\"\n"," Saves a subset of the augmented data for visualisation, by default in /content.\n","\n"," This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n"," \n"," \"\"\"\n"," try:\n"," os.mkdir(dir_augmented_data)\n"," except:\n"," ## if the preview folder exists, then remove\n"," ## the contents (pictures) in the folder\n"," for item in os.listdir(dir_augmented_data):\n"," os.remove(dir_augmented_data + \"/\" + item)\n","\n"," ## convert the original image to array\n"," x = img_to_array(orig_img)\n"," ## reshape (Sampke, Nrow, Ncol, 3) 3 = R, G or B\n"," #print(x.shape)\n"," x = x.reshape((1,) + x.shape)\n"," #print(x.shape)\n"," ## -------------------------- ##\n"," ## randomly generate pictures\n"," ## -------------------------- ##\n"," i = 0\n"," #We will just save 5 images,\n"," #but this can be changed, but note the visualisation in 3. 
currently uses 5.\n"," Nplot = 5\n"," for batch in datagen.flow(x,batch_size=1,\n"," save_to_dir=dir_augmented_data,\n"," save_format='tif',\n"," seed=42):\n"," i += 1\n"," if i > Nplot - 1:\n"," break\n","\n","# Generators\n","def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):\n"," '''\n"," Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n"," \n"," datagen: ImageDataGenerator \n"," subset: can take either 'training' or 'validation'\n"," '''\n"," seed = 1\n"," image_generator = image_datagen.flow_from_directory(\n"," os.path.dirname(image_folder_path),\n"," classes = [os.path.basename(image_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"bicubic\",\n"," seed = seed)\n"," \n"," mask_generator = mask_datagen.flow_from_directory(\n"," os.path.dirname(mask_folder_path),\n"," classes = [os.path.basename(mask_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"nearest\",\n"," seed = seed)\n"," \n"," this_generator = zip(image_generator, mask_generator)\n"," for (img,mask) in this_generator:\n"," # img,mask = adjustData(img,mask)\n"," yield (img,mask)\n","\n","\n","def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):\n"," image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n"," mask_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizeMinMax)\n","\n"," train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, target_size)\n"," validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size)\n","\n"," return (train_datagen, validation_datagen)\n","\n","\n","# Normalization functions from Martin Weigert\n","def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","\n","\n","# Simple normalization to min/max fir the Mask\n","def normalizeMinMax(x, dtype=np.float32):\n"," x = x.astype(dtype,copy=False)\n"," x = (x - np.amin(x)) / (np.amax(x) - np.amin(x))\n"," return x\n","\n","\n","# This is code outlines the 
architecture of U-net. The choice of pooling steps decides the depth of the network. \n","def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):\n"," inputs = Input(input_size)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n"," # Downsampling steps\n"," pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n"," \n"," if pooling_steps > 1:\n"," pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n","\n"," if pooling_steps > 2:\n"," pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)\n"," drop4 = Dropout(0.5)(conv4)\n"," \n"," if pooling_steps > 3:\n"," pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n"," drop5 = Dropout(0.5)(conv5)\n","\n"," #Upsampling steps\n"," up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))\n"," merge6 = concatenate([drop4,up6], axis = 3)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n"," \n"," if pooling_steps > 2:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n"," if pooling_steps > 3:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n"," merge7 = concatenate([conv3,up7], axis = 3)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)\n","\n"," if pooling_steps > 1:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n"," if pooling_steps > 2:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n"," merge8 = concatenate([conv2,up8], axis = 3)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)\n"," \n"," if pooling_steps == 1:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n"," else:\n"," up9 = Conv2D(64, 2, padding = 'same', 
kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n"," \n"," merge9 = concatenate([conv1,up9], axis = 3)\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)\n","\n"," model = Model(inputs = inputs, outputs = conv10)\n","\n"," # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])\n"," model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))\n","\n","\n"," if verbose:\n"," model.summary()\n","\n"," if(pretrained_weights):\n"," \tmodel.load_weights(pretrained_weights);\n","\n"," return model\n","\n","\n","\n","def predict_as_tiles(Image_path, model):\n","\n"," # Read the data in and normalize\n"," Image_raw = io.imread(Image_path, as_gray = True)\n"," Image_raw = normalizePercentile(Image_raw)\n","\n"," # Get the patch size from the input layer of the model\n"," patch_size = model.layers[0].output_shape[1:3]\n","\n"," # Pad the image with zeros if any of its dimensions is smaller than the patch size\n"," if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n"," Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n"," Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n"," else:\n"," Image = Image_raw\n","\n"," # Calculate the number of patches in each dimension\n"," n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n"," n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n","\n"," prediction = np.zeros(Image.shape)\n","\n"," for x in range(n_patch_in_width):\n"," for y in range(n_patch_in_height):\n"," xi = patch_size[0]*x\n"," yi = patch_size[1]*y\n","\n"," # If the patch exceeds the edge of the image shift it back \n"," if xi+patch_size[0] >= Image.shape[0]:\n"," xi = Image.shape[0]-patch_size[0]\n","\n"," if yi+patch_size[1] >= Image.shape[1]:\n"," yi = Image.shape[1]-patch_size[1]\n"," \n"," # Extract and reshape the patch\n"," patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n"," patch = np.reshape(patch,patch.shape+(1,))\n"," patch = np.reshape(patch,(1,)+patch.shape)\n","\n"," # Get the prediction from the patch and paste it in the prediction in the right place\n"," predicted_patch = model.predict(patch, batch_size = 1)\n"," prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = np.squeeze(predicted_patch)\n","\n","\n"," return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n"," \n","\n","\n","\n","def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):\n"," for (filename, image) in zip(source_dir_list, nparray):\n"," io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), img_as_ubyte(image)) # saving as unsigned 8-bit image\n"," \n"," # For masks, threshold the images and return 8 bit image\n"," if threshold is not None:\n"," mask = convert2Mask(image, threshold)\n"," io.imsave(os.path.join(save_path, prefix+'mask_'+os.path.splitext(filename)[0]+'.tif'), mask)\n","\n","\n","def convert2Mask(image, threshold):\n"," mask = img_as_ubyte(image, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n"," return mask\n","\n","\n","def 
getIoUvsThreshold(prediction_filepath, groud_truth_filepath):\n"," prediction = io.imread(prediction_filepath)\n"," ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)\n","\n"," threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," # Convert to 8-bit for calculating the IoU\n"," mask = img_as_ubyte(prediction, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n","\n"," # Intersection over Union metric\n"," intersection = np.logical_and(ground_truth_image, np.squeeze(mask))\n"," union = np.logical_or(ground_truth_image, np.squeeze(mask))\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return (threshold_list, IoU_scores_list)\n","\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net and dependencies installed.')\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","print('Notebook version: '+Notebook_version[0])\n","\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'U-Net 2D'\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n"," loss = str(model.loss)[str(model.loss).find('function')+len('function'):str(model.loss).find('.<')]\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. 
Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(180, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=1)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if rotation_range != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if horizontal_flip == True or vertical_flip == True:\n"," aug_text = aug_text+'\\n- flipping'\n"," if zoom_range != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if horizontal_shift != 0 or vertical_shift != 0:\n"," aug_text = aug_text+'\\n- shifting'\n"," if shear_range != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>number_of_steps</td><td>{3}</td></tr>
<tr><td>percentage_validation</td><td>{4}</td></tr>
<tr><td>initial_learning_rate</td><td>{5}</td></tr>
<tr><td>pooling_steps</td><td>{6}</td></tr>
<tr><td>min_fraction</td><td>{7}</td></tr>
\n"," \"\"\".format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size, number_of_steps, percentage_validation, initial_learning_rate, pooling_steps, min_fraction)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Unet2D.png').shape\n"," pdf.image('/content/TrainingDataExample_Unet2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Unet 2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Threshold Optimisation', ln=1, align='L')\n"," #pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png', x = 11, y = None, w = round(exp_size[1]/6), h = round(exp_size[0]/7))\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," IoU = header[1]\n"," IoU_OptThresh = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,IoU,IoU_OptThresh)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," image = 
row[0]\n"," IoU = row[1]\n"," IoU_OptThresh = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(IoU),3)),str(round(float(IoU_OptThresh),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n"," print('------------------------------')\n"," print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training data and models**\n","\n","**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (segmentation masks). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n","\n","**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n","\n","**`model_path`**: Enter the path of the folder where you want to save your model.\n","\n","**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n","\n"," **Select training parameters**\n","\n","**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. Smaller batchsizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset this can be due to a too large batch size. Decrease the number in this case. 
**Default: 4**\n","\n","**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. The default value is calculated to ensure this. This behaviour can also be obtained by setting it to 0. Other values can be used for testing.\n","\n"," **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**`patch_width` and `patch_height`:** The notebook crops the data in patches of fixed size prior to training. The dimensions of the patches can be defined here. When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n","\n","**`min_fraction`:** Minimum fraction of pixels being foreground for a slected patch to be considered valid. It should be between 0 and 1.**Default value: 0.02** (2%)\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# ------------- Initial user input ------------\n","#@markdown ###Path to training images:\n","Training_source = '' #@param {type:\"string\"}\n","Training_target = '' #@param {type:\"string\"}\n","\n","model_name = '' #@param {type:\"string\"}\n","model_path = '' #@param {type:\"string\"}\n","\n","#@markdown ###Training parameters:\n","#@markdown Number of epochs\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced parameters:\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"integer\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","pooling_steps = 2 #@param [1,2,3,4]{type:\"raw\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","patch_width = 512#@param{type:\"number\"}\n","patch_height = 512#@param{type:\"number\"}\n","min_fraction = 0.02#@param{type:\"number\"}\n","\n","\n","# ------------- Initialising folder, variables and failsafes ------------\n","# Create the folders where to save the model and the QC\n","full_model_path = os.path.join(model_path, model_name)\n","if os.path.exists(full_model_path):\n"," print(R+'!! 
WARNING: Folder already exists and will be overwritten !!'+W)\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," pooling_steps = 2\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n"," patch_width, patch_height = estimatePatchSize(Training_source)\n"," min_fraction = 0.02\n","\n","\n","#The create_patches function will create the two folders below\n","# Patch_source = '/content/img_patches'\n","# Patch_target = '/content/mask_patches'\n","print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n","\n","#Create patches\n","print('Creating patches...')\n","Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction)\n","\n","number_of_training_dataset = len(os.listdir(Patch_source))\n","print('Total number of valid patches: '+str(number_of_training_dataset))\n","\n","if Use_Default_Advanced_Parameters or number_of_steps == 0:\n"," number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n","print('Number of steps: '+str(number_of_steps))\n","\n","# Calculate the number of steps to use for validation\n","validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n","\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = False\n","# Build the default dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = 0.,\n"," height_shift_range = 0.,\n"," rotation_range = 0., #90\n"," zoom_range = 0.,\n"," shear_range = 0.,\n"," horizontal_flip = False,\n"," vertical_flip = False,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","# ------------- Display ------------\n","\n","#if not os.path.exists('/content/img_patches/'):\n","random_choice = random.choice(os.listdir(Patch_source))\n","x = io.imread(os.path.join(Patch_source, random_choice))\n","\n","#os.chdir(Training_target)\n","y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest',cmap='gray')\n","plt.title('Training image patch')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest',cmap='gray')\n","plt.title('Training mask patch')\n","plt.axis('off');\n","\n","plt.savefig('/content/TrainingDataExample_Unet2D.png',bbox_inches='tight',pad_inches=0)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n","\n"," The augmentation options below are to be used as follows:\n","\n","* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n","* **zoom_range**: Increasing or decreasing the field of view. E.g. 
10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n","* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n","* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n","* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#@markdown ##**Augmentation options**\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," if Use_Default_Augmentation_Parameters:\n"," horizontal_shift = 10 \n"," vertical_shift = 20 \n"," zoom_range = 10\n"," shear_range = 10\n"," horizontal_flip = True\n"," vertical_flip = True\n"," rotation_range = 180\n","#@markdown ###If you are not using the default settings, please provide the values below:\n","\n","#@markdown ###**Image shift, zoom, shear and flip (%)**\n"," else:\n"," horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," vertical_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," shear_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," horizontal_flip = True #@param {type:\"boolean\"}\n"," vertical_flip = True #@param {type:\"boolean\"}\n","\n","#@markdown ###**Rotate image within angle range (degrees):**\n"," rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n","\n","#given behind the # are the default values for each parameter.\n","\n","else:\n"," horizontal_shift = 0 \n"," vertical_shift = 0 \n"," zoom_range = 0\n"," shear_range = 0\n"," horizontal_flip = False\n"," vertical_flip = False\n"," rotation_range = 0\n","\n","\n","# Build the dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n"," height_shift_range = vertical_shift/100.,\n"," rotation_range = rotation_range, #90\n"," zoom_range = zoom_range/100.,\n"," shear_range = shear_range/100.,\n"," horizontal_flip = horizontal_flip,\n"," vertical_flip = vertical_flip,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","\n","\n","# ------------- Display ------------\n","dir_augmented_data_imgs=\"/content/augment_img\"\n","dir_augmented_data_masks=\"/content/augment_mask\"\n","random_choice = random.choice(os.listdir(Patch_source))\n","orig_img = load_img(os.path.join(Patch_source,random_choice))\n","orig_mask = load_img(os.path.join(Patch_target,random_choice))\n","\n","augment_view = ImageDataGenerator(**data_gen_args)\n","\n","if Use_Data_augmentation:\n"," print(\"Parameters enabled\")\n"," print(\"Here is what a subset of your augmentations looks like:\")\n"," save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n"," save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n","\n"," fig = plt.figure(figsize=(15, 7))\n"," fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n","\n"," \n"," ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[]) \n"," new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n"," ax.imshow(new_img)\n"," ax.set_title('Original Image')\n"," i = 2\n"," for imgnm in os.listdir(dir_augmented_data_imgs):\n"," ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) \n"," 
img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n"," ax.imshow(img)\n"," i += 1\n","\n"," ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[]) \n"," new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n"," ax.imshow(new_mask)\n"," ax.set_title('Original Mask')\n"," j=2\n"," for imgnm in os.listdir(dir_augmented_data_masks):\n"," ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) \n"," mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n"," ax.imshow(mask)\n"," j += 1\n"," plt.show()\n","\n","else:\n"," print(\"No augmentation will be used\")\n","\n"," "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = True #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the UNET_Model_from_\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not 
os.path.exists(h5_file_path):\n"," print(R+'WARNING: pretrained model does not exist')\n"," Use_pretrained_model = False\n"," \n","\n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(R+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---\n","####**Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time. "]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. 
to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Play this cell to prepare the model for training\n","\n","\n","# ------------------ Set the generators, model and logger ------------------\n","# This will take the image size and set that as a patch size (arguable...)\n","# Read image size (without actuall reading the data)\n","\n","(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))\n","\n","\n","# This modelcheckpoint will only save the best model from the validation loss point of view\n","model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)\n","\n","print('Getting class weights...')\n","class_weights = getClassWeights(Training_target)\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","else:\n"," h5_file_path = None\n","\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Reduce learning rate on plateau ------------------------\n","\n","reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto',\n"," patience=10, min_lr=0)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Define the model\n","model = unet(pretrained_weights = h5_file_path, \n"," input_size = (patch_width,patch_height,1), \n"," pooling_steps = pooling_steps, \n"," learning_rate = initial_learning_rate, \n"," class_weights = class_weights)\n","\n","config_model= model.optimizer.get_config()\n","print(config_model)\n","\n","\n","# ------------------ Failsafes ------------------\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Model folder already existed and has been removed !!'+W)\n"," shutil.rmtree(full_model_path)\n","\n","os.makedirs(full_model_path)\n","os.makedirs(os.path.join(full_model_path,'Quality Control'))\n","\n","\n","# ------------------ Display ------------------\n","print('---------------------------- Main training parameters ----------------------------')\n","print('Number of epochs: '+str(number_of_epochs))\n","print('Batch size: '+str(batch_size))\n","print('Number of training dataset: '+str(number_of_training_dataset))\n","print('Number of training steps: '+str(number_of_steps))\n","print('Number of validation steps: '+str(validation_steps))\n","print('---------------------------- ------------------------ ----------------------------')\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! 
If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","# history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs=epochs, callbacks=[model_checkpoint,csv_log], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs = number_of_epochs, callbacks=[model_checkpoint, reduce_lr], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","\n","# Save the last model\n","model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). \n","lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n"," \n","\n","\n","# Displaying the time elapsed for training\n","print(\"------------------------------------------\")\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n","print(\"------------------------------------------\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. 
\n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n","if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. 
epoch number\n","\n","epochNumber = []\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'),bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n","\n"," The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n","\n","### **Thresholds for image masks**\n","\n"," Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. The best threshold for each image and the average of these thresholds will be displayed below. 
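For illustration, the IoU at a given threshold can be computed as in the minimal sketch below. The names `compute_iou`, `pred` and `gt` are hypothetical and not part of this notebook, which performs the same calculation internally with its own `getIoUvsThreshold` function.

```python
import numpy as np

def compute_iou(prediction, ground_truth, threshold):
    # Binarise the prediction at the chosen threshold (0-255, as used in this section)
    mask = prediction > threshold
    gt = ground_truth > 0
    intersection = np.logical_and(gt, mask).sum()
    union = np.logical_or(gt, mask).sum()
    return intersection / union if union > 0 else 0.0

# Scanning all thresholds and keeping the best one (pred and gt are example arrays):
# best_threshold = max(range(256), key=lambda t: compute_iou(pred, gt, t))
```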
**These values can be a guideline when creating masks for unseen data in section 6.**"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","# ------------- Initialise folders ------------\n","# Create a quality control/Prediction Folder\n","prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n","if os.path.exists(prediction_QC_folder):\n"," shutil.rmtree(prediction_QC_folder)\n","\n","os.makedirs(prediction_QC_folder)\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model\n","unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Source_QC_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","#-----------------------------Calculate Metrics----------------------------------------#\n","\n","f = plt.figure(figsize=((5,5)))\n","\n","with open(os.path.join(full_QC_model_path,'Quality Control', 'QC_metrics_'+QC_model_name+'.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"File name\",\"IoU\", \"IoU-optimised threshold\"]) \n","\n"," # Initialise the lists \n"," filename_list = []\n"," best_threshold_list = []\n"," best_IoU_score_list = []\n","\n"," for filename in os.listdir(Source_QC_folder):\n","\n"," if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n"," print('Running QC on: '+filename)\n"," test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n","\n"," (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(os.path.join(prediction_QC_folder, prediction_prefix+filename), os.path.join(Target_QC_folder, filename))\n"," plt.plot(threshold_list,iou_scores_per_threshold, label=filename)\n","\n"," # Here we find which threshold yielded the highest IoU score for image n.\n"," best_IoU_score = max(iou_scores_per_threshold)\n"," best_threshold = iou_scores_per_threshold.index(best_IoU_score)\n","\n"," # Write the results in the CSV file\n"," writer.writerow([filename, str(best_IoU_score), str(best_threshold)])\n","\n"," # Here we append the best threshold and score to the lists\n"," filename_list.append(filename)\n"," best_IoU_score_list.append(best_IoU_score)\n"," best_threshold_list.append(best_threshold)\n","\n","# Display the IoV vs Threshold plot\n","plt.title('IoU vs. 
Threshold')\n","plt.ylabel('IoU')\n","plt.xlabel('Threshold value')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = filename_list)\n","pdResults[\"IoU\"] = best_IoU_score_list\n","pdResults[\"IoU-optimised threshold\"] = best_threshold_list\n","\n","average_best_threshold = sum(best_threshold_list)/len(best_threshold_list)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Source_QC_folder)):\n"," \n"," plt.figure(figsize=(25,5))\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n"," plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n"," test_prediction_mask = np.empty_like(test_prediction)\n"," test_prediction_mask[test_prediction > average_best_threshold] = 255\n"," test_prediction_mask[test_prediction <= average_best_threshold] = 0\n"," plt.imshow(test_prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_image, cmap='Greens')\n"," plt.imshow(test_prediction_mask, alpha=0.5, cmap='Purples')\n"," metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ' T: ' + str(round(pdResults.loc[file][\"IoU-optimised threshold\"])) + ')'\n"," plt.title(metrics_title)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","print('--------------------------------------------------------------')\n","print('Best average threshold is: '+str(round(average_best_threshold)))\n","print('--------------------------------------------------------------')\n","\n","pdResults.head()\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
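If you want to reproduce the threshold optimisation outside of this notebook, a minimal sketch of the same sweep for a single image pair could look like the code below (the file names are hypothetical; inside the notebook the `getIoUvsThreshold` helper does this for you):

```python
import numpy as np
from skimage import io, img_as_ubyte

# Hypothetical file names: one saved prediction and its ground-truth mask.
prediction   = img_as_ubyte(io.imread('Predicted_img_1.tif', as_gray=True))  # 8-bit prediction
ground_truth = io.imread('mask_1.tif', as_gray=True) > 0                     # boolean ground truth

iou_per_threshold = []
for threshold in range(256):
    mask = prediction > threshold
    union = np.logical_or(ground_truth, mask).sum()
    iou_per_threshold.append(np.logical_and(ground_truth, mask).sum() / union if union > 0 else 0.0)

best_iou = max(iou_per_threshold)
best_threshold = iou_per_threshold.index(best_iou)  # the list index is the threshold itself
print('Best IoU %.3f at threshold %d' % (best_iou, best_threshold))
```

Because the scores are stored in order for thresholds 0 to 255, the index of the best score is directly the optimal threshold, which is how the cell above reports it.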
Predicted output images are saved in your **Result_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n"," Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n","\n"," **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. Train the network again with different network hyperparameters."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["\n","\n","# ------------- Initial user input ------------\n","#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","Data_folder = '' #@param {type:\"string\"}\n","Results_folder = '' #@param {type:\"string\"}\n","\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","# ------------- Failsafes ------------\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model and prepare generator\n","\n","\n","\n","unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Data_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n"," # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","\n","\n","def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n","\n"," plt.figure(figsize=(18,6))\n"," # Wide-field\n"," plt.subplot(1,3,1)\n"," plt.axis('off')\n"," img_Source = plt.imread(os.path.join(Data_folder, file))\n"," plt.imshow(img_Source, cmap='gray')\n"," plt.title('Source image',fontsize=15)\n"," # Prediction\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))\n"," plt.imshow(img_Prediction, cmap='gray')\n"," plt.title('Prediction',fontsize=15)\n","\n"," # Thresholded mask\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," img_Mask = convert2Mask(img_Prediction, threshold)\n"," plt.imshow(img_Mask, cmap='gray')\n"," plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n","\n","\n","interact(show_prediction_mask, continuous_update=False);\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"stS96mFZLMOU"},"source":["## **6.2. Export results as masks**\n","---\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"qb5ZmFstLNbR"},"source":["\n","# @markdown #Play this cell to save results as masks with the chosen threshold\n","threshold = 120#@param {type:\"number\"}\n","\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n","print('-------------------')\n","print('Masks were saved in: '+Results_folder)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. 
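One convenient way to grab everything in one go before cleaning up is to zip the results first; a minimal sketch (the path is hypothetical) is:

```python
import shutil

# Hypothetical path: the Results_folder used in section 6.1.
Results_folder = '/content/gdrive/My Drive/Results'

# Creates Results.zip next to the folder; a single archive downloads much faster
# from Google Drive than many individual TIFF files.
shutil.make_archive(Results_folder, 'zip', Results_folder)
```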
Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using 2D U-Net!**\n"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1USP7Bmd4UEdhp9cOBc_wlqDXnZjMHk_f","timestamp":1622215174280},{"file_id":"1EZG34jBKULVmO__Fmv7Lr76sVHIMxwJx","timestamp":1622041273450},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **U-Net (2D)**\n","---\n","\n","U-Net is an encoder-decoder network architecture originally used for image segmentation, first published by [Ronneberger *et al.*](https://arxiv.org/abs/1505.04597). The first half of the U-Net architecture is a downsampling convolutional neural network which acts as a feature extractor from input images. The other half upsamples these results and restores an image by combining results from downsampling with the upsampled images.\n","\n"," **This particular notebook enables image segmentation of 2D dataset. If you are interested in 3D dataset, you should use the 3D U-Net notebook instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is largely based on the papers: \n","\n","**U-Net: Convolutional Networks for Biomedical Image Segmentation** by Ronneberger *et al.* published on arXiv in 2015 (https://arxiv.org/abs/1505.04597)\n","\n","and \n","\n","**U-Net: deep learning for cell counting, detection, and morphometry** by Thorsten Falk *et al.* in Nature Methods 2019\n","(https://www.nature.com/articles/s41592-018-0261-2)\n","And source code found in: /~https://github.com/zhixuhao/unet by *Zhixuhao*\n","\n","**Please also cite this original paper when using or developing this notebook.** "]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. 
Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For U-Net to train, **it needs to have access to a paired training dataset corresponding to images and their corresponding masks**. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n","Additionally, the corresponding Training_source and Training_target files need to have **the same name**.\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif, ...\n"," - Training_target\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.tif, img_2.tif\n"," - Training_target \n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. Install U-Net dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"UGWnGOFsf07b"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"uc0haIa-fZiG","cellView":"form"},"source":["#@markdown ##Play to install U-Net dependencies\n","\n","!pip install data\n","!pip install fpdf\n","!pip install h5py==2.10\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"I4O5zctbf4Gb"},"source":["\n","## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
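Coming back to the data organisation suggested in section 0: if you prefer to create the folder skeleton programmatically instead of by hand in the Drive web interface, a minimal sketch is shown below (the base path is hypothetical and assumes your Google Drive is already mounted as described in section 2.2):

```python
import os

# Hypothetical base path inside your mounted Google Drive.
base = '/content/gdrive/My Drive/Experiment_A'

for folder in ['Training dataset/Training_source',
               'Training dataset/Training_target',
               'Quality control dataset/Training_source',
               'Quality control dataset/Training_target',
               'Data to be predicted',
               'Results']:
    os.makedirs(os.path.join(base, folder), exist_ok=True)

print('Folder skeleton created under', base)
```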
"]},{"cell_type":"markdown","metadata":{"id":"iiX3Ly-7gA5h"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'U-Net (2D)'\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Load key U-Net dependencies\n","\n","#As this notebokk depends mostly on keras which runs a tensorflow backend (which in turn is pre-installed in colab)\n","#only the data library needs to be additionally installed.\n","%tensorflow_version 1.x\n","import tensorflow as tf\n","# print(tensorflow.__version__)\n","# print(\"Tensorflow enabled.\")\n","\n","\n","# Keras imports\n","from keras import models\n","from keras.models import Model, load_model\n","from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D\n","from keras.optimizers import Adam\n","# from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger # we currently don't use any other callbacks from ModelCheckpoints\n","from keras.callbacks import ModelCheckpoint\n","from keras.callbacks import ReduceLROnPlateau\n","from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img\n","from keras import backend as keras\n","\n","# General import\n","from __future__ import print_function\n","import numpy as np\n","import pandas as pd\n","import os\n","import glob\n","from skimage import img_as_ubyte, io, transform\n","import matplotlib as mpl\n","from matplotlib import pyplot as plt\n","from matplotlib.pyplot import imread\n","from pathlib import Path\n","import shutil\n","import random\n","import time\n","import csv\n","import sys\n","from math import ceil\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","# Imports for QC\n","from PIL import Image\n","from scipy import signal\n","from scipy import ndimage\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import 
peak_signal_noise_ratio as psnr\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","# from tqdm import tqdm\n","from tqdm.notebook import tqdm\n","\n","from sklearn.feature_extraction import image\n","from skimage import img_as_ubyte, io, transform\n","from skimage.util.shape import view_as_windows\n","\n","from datetime import datetime\n","\n","\n","# Suppressing some warnings\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","\n","\n","\n","def create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction):\n"," \"\"\"\n"," Function creates patches from the Training_source and Training_target images. \n"," The steps parameter indicates the offset between patches and, if integer, is the same in x and y.\n"," Saves all created patches in two new directories in the /content folder.\n","\n"," min_fraction is the minimum fraction of pixels that need to be foreground to be considered as a valid patch\n","\n"," Returns: - Two paths to where the patches are now saved\n"," \"\"\"\n"," DEBUG = False\n","\n"," Patch_source = os.path.join('/content','img_patches')\n"," Patch_target = os.path.join('/content','mask_patches')\n"," Patch_rejected = os.path.join('/content','rejected')\n"," \n","\n"," #Here we save the patches, in the /content directory as they will not usually be needed after training\n"," if os.path.exists(Patch_source):\n"," shutil.rmtree(Patch_source)\n"," if os.path.exists(Patch_target):\n"," shutil.rmtree(Patch_target)\n"," if os.path.exists(Patch_rejected):\n"," shutil.rmtree(Patch_rejected)\n","\n"," os.mkdir(Patch_source)\n"," os.mkdir(Patch_target)\n"," os.mkdir(Patch_rejected) #This directory will contain the images that have too little signal.\n"," \n"," patch_num = 0\n","\n"," for file in tqdm(os.listdir(Training_source)):\n","\n"," img = io.imread(os.path.join(Training_source, file))\n"," mask = io.imread(os.path.join(Training_target, file),as_gray=True)\n","\n"," if DEBUG:\n"," print(file)\n"," print(img.dtype)\n","\n"," # Using view_as_windows with step size equal to the patch size to ensure there is no overlap\n"," patches_img = view_as_windows(img, (patch_width, patch_height), (patch_width, patch_height))\n"," patches_mask = view_as_windows(mask, (patch_width, patch_height), (patch_width, patch_height))\n","\n"," patches_img = patches_img.reshape(patches_img.shape[0]*patches_img.shape[1], patch_width,patch_height)\n"," patches_mask = patches_mask.reshape(patches_mask.shape[0]*patches_mask.shape[1], patch_width,patch_height)\n","\n"," if DEBUG:\n"," print(all_patches_img.shape)\n"," print(all_patches_img.dtype)\n","\n"," for i in range(patches_img.shape[0]):\n"," img_save_path = os.path.join(Patch_source,'patch_'+str(patch_num)+'.tif')\n"," mask_save_path = os.path.join(Patch_target,'patch_'+str(patch_num)+'.tif')\n"," patch_num += 1\n","\n"," # if the mask conatins at least 2% of its total number pixels as mask, then go ahead and save the images\n"," pixel_threshold_array = sorted(patches_mask[i].flatten())\n"," if pixel_threshold_array[int(round(len(pixel_threshold_array)*(1-min_fraction)))]>0:\n"," io.imsave(img_save_path, img_as_ubyte(normalizeMinMax(patches_img[i])))\n"," io.imsave(mask_save_path, convert2Mask(normalizeMinMax(patches_mask[i]),0))\n"," else:\n"," io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_image.tif', img_as_ubyte(normalizeMinMax(patches_img[i])))\n"," io.imsave(Patch_rejected+'/patch_'+str(patch_num)+'_mask.tif', 
convert2Mask(normalizeMinMax(patches_mask[i]),0))\n","\n"," return Patch_source, Patch_target\n","\n","\n","def estimatePatchSize(data_path, max_width = 512, max_height = 512):\n","\n"," files = os.listdir(data_path)\n"," \n"," # Get the size of the first image found in the folder and initialise the variables to that\n"," n = 0 \n"," while os.path.isdir(os.path.join(data_path, files[n])):\n"," n += 1\n"," (height_min, width_min) = Image.open(os.path.join(data_path, files[n])).size\n","\n"," # Screen the size of all dataset to find the minimum image size\n"," for file in files:\n"," if not os.path.isdir(os.path.join(data_path, file)):\n"," (height, width) = Image.open(os.path.join(data_path, file)).size\n"," if width < width_min:\n"," width_min = width\n"," if height < height_min:\n"," height_min = height\n"," \n"," # Find the power of patches that will fit within the smallest dataset\n"," width_min, height_min = (fittingPowerOfTwo(width_min), fittingPowerOfTwo(height_min))\n","\n"," # Clip values at maximum permissible values\n"," if width_min > max_width:\n"," width_min = max_width\n","\n"," if height_min > max_height:\n"," height_min = max_height\n"," \n"," return (width_min, height_min)\n","\n","def fittingPowerOfTwo(number):\n"," n = 0\n"," while 2**n <= number:\n"," n += 1 \n"," return 2**(n-1)\n","\n","\n","def getClassWeights(Training_target_path):\n","\n"," Mask_dir_list = os.listdir(Training_target_path)\n"," number_of_dataset = len(Mask_dir_list)\n","\n"," class_count = np.zeros(2, dtype=int)\n"," for i in tqdm(range(number_of_dataset)):\n"," mask = io.imread(os.path.join(Training_target_path, Mask_dir_list[i]))\n"," mask = normalizeMinMax(mask)\n"," class_count[0] += mask.shape[0]*mask.shape[1] - mask.sum()\n"," class_count[1] += mask.sum()\n","\n"," n_samples = class_count.sum()\n"," n_classes = 2\n","\n"," class_weights = n_samples / (n_classes * class_count)\n"," return class_weights\n","\n","def weighted_binary_crossentropy(class_weights):\n","\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = keras.binary_crossentropy(y_true, y_pred)\n"," weight_vector = y_true * class_weights[1] + (1. - y_true) * class_weights[0]\n"," weighted_binary_crossentropy = weight_vector * binary_crossentropy\n","\n"," return keras.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","\n","def save_augment(datagen,orig_img,dir_augmented_data=\"/content/augment\"):\n"," \"\"\"\n"," Saves a subset of the augmented data for visualisation, by default in /content.\n","\n"," This is adapted from: https://fairyonice.github.io/Learn-about-ImageDataGenerator.html\n"," \n"," \"\"\"\n"," try:\n"," os.mkdir(dir_augmented_data)\n"," except:\n"," ## if the preview folder exists, then remove\n"," ## the contents (pictures) in the folder\n"," for item in os.listdir(dir_augmented_data):\n"," os.remove(dir_augmented_data + \"/\" + item)\n","\n"," ## convert the original image to array\n"," x = img_to_array(orig_img)\n"," ## reshape (Sampke, Nrow, Ncol, 3) 3 = R, G or B\n"," #print(x.shape)\n"," x = x.reshape((1,) + x.shape)\n"," #print(x.shape)\n"," ## -------------------------- ##\n"," ## randomly generate pictures\n"," ## -------------------------- ##\n"," i = 0\n"," #We will just save 5 images,\n"," #but this can be changed, but note the visualisation in 3. 
currently uses 5.\n"," Nplot = 5\n"," for batch in datagen.flow(x,batch_size=1,\n"," save_to_dir=dir_augmented_data,\n"," save_format='tif',\n"," seed=42):\n"," i += 1\n"," if i > Nplot - 1:\n"," break\n","\n","# Generators\n","def buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, subset, batch_size, target_size):\n"," '''\n"," Can generate image and mask at the same time use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same\n"," \n"," datagen: ImageDataGenerator \n"," subset: can take either 'training' or 'validation'\n"," '''\n"," seed = 1\n"," image_generator = image_datagen.flow_from_directory(\n"," os.path.dirname(image_folder_path),\n"," classes = [os.path.basename(image_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"bicubic\",\n"," seed = seed)\n"," \n"," mask_generator = mask_datagen.flow_from_directory(\n"," os.path.dirname(mask_folder_path),\n"," classes = [os.path.basename(mask_folder_path)],\n"," class_mode = None,\n"," color_mode = \"grayscale\",\n"," target_size = target_size,\n"," batch_size = batch_size,\n"," subset = subset,\n"," interpolation = \"nearest\",\n"," seed = seed)\n"," \n"," this_generator = zip(image_generator, mask_generator)\n"," for (img,mask) in this_generator:\n"," # img,mask = adjustData(img,mask)\n"," yield (img,mask)\n","\n","\n","def prepareGenerators(image_folder_path, mask_folder_path, datagen_parameters, batch_size = 4, target_size = (512, 512)):\n"," image_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizePercentile)\n"," mask_datagen = ImageDataGenerator(**datagen_parameters, preprocessing_function = normalizeMinMax)\n","\n"," train_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'training', batch_size, target_size)\n"," validation_datagen = buildDoubleGenerator(image_datagen, mask_datagen, image_folder_path, mask_folder_path, 'validation', batch_size, target_size)\n","\n"," return (train_datagen, validation_datagen)\n","\n","\n","# Normalization functions from Martin Weigert\n","def normalizePercentile(x, pmin=1, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","\n","\n","# Simple normalization to min/max fir the Mask\n","def normalizeMinMax(x, dtype=np.float32):\n"," x = x.astype(dtype,copy=False)\n"," x = (x - np.amin(x)) / (np.amax(x) - np.amin(x))\n"," return x\n","\n","\n","# This is code outlines the 
architecture of U-net. The choice of pooling steps decides the depth of the network. \n","def unet(pretrained_weights = None, input_size = (256,256,1), pooling_steps = 4, learning_rate = 1e-4, verbose=True, class_weights=np.ones(2)):\n"," inputs = Input(input_size)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)\n"," conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)\n"," # Downsampling steps\n"," pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)\n"," conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)\n"," \n"," if pooling_steps > 1:\n"," pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)\n"," conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)\n","\n"," if pooling_steps > 2:\n"," pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)\n"," conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)\n"," drop4 = Dropout(0.5)(conv4)\n"," \n"," if pooling_steps > 3:\n"," pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)\n"," conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)\n"," drop5 = Dropout(0.5)(conv5)\n","\n"," #Upsampling steps\n"," up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))\n"," merge6 = concatenate([drop4,up6], axis = 3)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)\n"," conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)\n"," \n"," if pooling_steps > 2:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop4))\n"," if pooling_steps > 3:\n"," up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))\n"," merge7 = concatenate([conv3,up7], axis = 3)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)\n"," conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)\n","\n"," if pooling_steps > 1:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv3))\n"," if pooling_steps > 2:\n"," up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))\n"," merge8 = concatenate([conv2,up8], axis = 3)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)\n"," conv8 = Conv2D(128, 3, activation= 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)\n"," \n"," if pooling_steps == 1:\n"," up9 = Conv2D(64, 2, padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv2))\n"," else:\n"," up9 = Conv2D(64, 2, padding = 'same', 
kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) #activation = 'relu'\n"," \n"," merge9 = concatenate([conv1,up9], axis = 3)\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(merge9) #activation = 'relu'\n"," conv9 = Conv2D(64, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv9 = Conv2D(2, 3, padding = 'same', kernel_initializer = 'he_normal')(conv9) #activation = 'relu'\n"," conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)\n","\n"," model = Model(inputs = inputs, outputs = conv10)\n","\n"," # model.compile(optimizer = Adam(lr = learning_rate), loss = 'binary_crossentropy', metrics = ['acc'])\n"," model.compile(optimizer = Adam(lr = learning_rate), loss = weighted_binary_crossentropy(class_weights))\n","\n","\n"," if verbose:\n"," model.summary()\n","\n"," if(pretrained_weights):\n"," \tmodel.load_weights(pretrained_weights);\n","\n"," return model\n","\n","\n","\n","def predict_as_tiles(Image_path, model):\n","\n"," # Read the data in and normalize\n"," Image_raw = io.imread(Image_path, as_gray = True)\n"," Image_raw = normalizePercentile(Image_raw)\n","\n"," # Get the patch size from the input layer of the model\n"," patch_size = model.layers[0].output_shape[1:3]\n","\n"," # Pad the image with zeros if any of its dimensions is smaller than the patch size\n"," if Image_raw.shape[0] < patch_size[0] or Image_raw.shape[1] < patch_size[1]:\n"," Image = np.zeros((max(Image_raw.shape[0], patch_size[0]), max(Image_raw.shape[1], patch_size[1])))\n"," Image[0:Image_raw.shape[0], 0: Image_raw.shape[1]] = Image_raw\n"," else:\n"," Image = Image_raw\n","\n"," # Calculate the number of patches in each dimension\n"," n_patch_in_width = ceil(Image.shape[0]/patch_size[0])\n"," n_patch_in_height = ceil(Image.shape[1]/patch_size[1])\n","\n"," prediction = np.zeros(Image.shape)\n","\n"," for x in range(n_patch_in_width):\n"," for y in range(n_patch_in_height):\n"," xi = patch_size[0]*x\n"," yi = patch_size[1]*y\n","\n"," # If the patch exceeds the edge of the image shift it back \n"," if xi+patch_size[0] >= Image.shape[0]:\n"," xi = Image.shape[0]-patch_size[0]\n","\n"," if yi+patch_size[1] >= Image.shape[1]:\n"," yi = Image.shape[1]-patch_size[1]\n"," \n"," # Extract and reshape the patch\n"," patch = Image[xi:xi+patch_size[0], yi:yi+patch_size[1]]\n"," patch = np.reshape(patch,patch.shape+(1,))\n"," patch = np.reshape(patch,(1,)+patch.shape)\n","\n"," # Get the prediction from the patch and paste it in the prediction in the right place\n"," predicted_patch = model.predict(patch, batch_size = 1)\n"," prediction[xi:xi+patch_size[0], yi:yi+patch_size[1]] = np.squeeze(predicted_patch)\n","\n","\n"," return prediction[0:Image_raw.shape[0], 0: Image_raw.shape[1]]\n"," \n","\n","\n","\n","def saveResult(save_path, nparray, source_dir_list, prefix='', threshold=None):\n"," for (filename, image) in zip(source_dir_list, nparray):\n"," io.imsave(os.path.join(save_path, prefix+os.path.splitext(filename)[0]+'.tif'), img_as_ubyte(image)) # saving as unsigned 8-bit image\n"," \n"," # For masks, threshold the images and return 8 bit image\n"," if threshold is not None:\n"," mask = convert2Mask(image, threshold)\n"," io.imsave(os.path.join(save_path, prefix+'mask_'+os.path.splitext(filename)[0]+'.tif'), mask)\n","\n","\n","def convert2Mask(image, threshold):\n"," mask = img_as_ubyte(image, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n"," return mask\n","\n","\n","def 
getIoUvsThreshold(prediction_filepath, groud_truth_filepath):\n"," prediction = io.imread(prediction_filepath)\n"," ground_truth_image = img_as_ubyte(io.imread(groud_truth_filepath, as_gray=True), force_copy=True)\n","\n"," threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," # Convert to 8-bit for calculating the IoU\n"," mask = img_as_ubyte(prediction, force_copy=True)\n"," mask[mask > threshold] = 255\n"," mask[mask <= threshold] = 0\n","\n"," # Intersection over Union metric\n"," intersection = np.logical_and(ground_truth_image, np.squeeze(mask))\n"," union = np.logical_or(ground_truth_image, np.squeeze(mask))\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return (threshold_list, IoU_scores_list)\n","\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net and dependencies installed.')\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","# Check if this is the latest version of the notebook\n","\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n","\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n"," loss = str(model.loss)[str(model.loss).find('function')+len('function'):str(model.loss).find('.<')]\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(number_of_training_dataset)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_width)+','+str(patch_height)+')) with a batch size of '+str(batch_size)+' and a'+loss+' loss function,'+' using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. 
Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(180, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=1)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if rotation_range != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if horizontal_flip == True or vertical_flip == True:\n"," aug_text = aug_text+'\\n- flipping'\n"," if zoom_range != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if horizontal_shift != 0 or vertical_shift != 0:\n"," aug_text = aug_text+'\\n- shifting'\n"," if shear_range != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40%>\n","    <tr>\n","      <th width=50% align=\"left\">Parameter</th>\n","      <th width=50% align=\"left\">Value</th>\n","    </tr>\n","    <tr>\n","      <td width=50%>number_of_epochs</td>\n","      <td width=50%>{0}</td>\n","    </tr>
\n","    <tr>\n","      <td width=50%>patch_size</td>\n","      <td width=50%>{1}</td>\n","    </tr>\n","    <tr>\n","      <td width=50%>batch_size</td>\n","      <td width=50%>{2}</td>\n","    </tr>
\n","    <tr>\n","      <td width=50%>number_of_steps</td>\n","      <td width=50%>{3}</td>\n","    </tr>\n","    <tr>\n","      <td width=50%>percentage_validation</td>\n","      <td width=50%>{4}</td>\n","    </tr>
\n","    <tr>\n","      <td width=50%>initial_learning_rate</td>\n","      <td width=50%>{5}</td>\n","    </tr>\n","    <tr>\n","      <td width=50%>pooling_steps</td>\n","      <td width=50%>{6}</td>\n","    </tr>
\n","    <tr>\n","      <td width=50%>min_fraction</td>\n","      <td width=50%>{7}</td>\n","    </tr>\n","  </table>
\n"," \"\"\".format(number_of_epochs, str(patch_width)+'x'+str(patch_height), batch_size, number_of_steps, percentage_validation, initial_learning_rate, pooling_steps, min_fraction)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Unet2D.png').shape\n"," pdf.image('/content/TrainingDataExample_Unet2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Unet 2D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if os.path.exists(full_QC_model_path+'/Quality Control/lossCurvePlots.png'):\n"," pdf.image(full_QC_model_path+'/Quality Control/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/12), h = round(exp_size[0]/3))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.',align='L')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Threshold Optimisation', ln=1, align='L')\n"," #pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png', x = 11, y = None, w = round(exp_size[1]/6), h = round(exp_size[0]/7))\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," IoU = header[1]\n"," IoU_OptThresh = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,IoU,IoU_OptThresh)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," image = 
row[0]\n"," IoU = row[1]\n"," IoU_OptThresh = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(IoU),3)),str(round(float(IoU_OptThresh),3)))\n"," html = html+cells\n"," html = html+\"\"\"
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet: Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n"," print('------------------------------')\n"," print('QC PDF report exported as '+full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Complete the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n","\n","# from tensorflow.python.client import device_lib \n","# device_lib.list_local_devices()\n","\n","# print the tensorflow version\n","print('Tensorflow version is ' + str(tf.__version__))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dm3eCMYB5d-H"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
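Once the cell above has run, a quick way to confirm that the Drive is really mounted and that your folders are visible is a short check like this (the mount point is the default one used by the cell in section 2.2):

```python
import os

drive_root = '/content/gdrive/My Drive'   # default mount point created by drive.mount('/content/gdrive')

if os.path.isdir(drive_root):
    print('Google Drive is mounted. First entries:')
    for name in sorted(os.listdir(drive_root))[:10]:
        print(' -', name)
else:
    print('Google Drive does not appear to be mounted - re-run the cell in section 2.2.')
```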
"]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training data and models**\n","\n","**`Training_source`, `Training_target`:** These are the folders containing your source (e.g. EM images) and target files (segmentation masks). Enter the path to the source and target images for training. **These should be located in the same parent folder.**\n","\n","**`model_name`:** Use only my_model -style, not my-model. If you want to use a previously trained model, enter the name of the pretrained model (which should be contained in the trained_model -folder after training).\n","\n","**`model_path`**: Enter the path of the folder where you want to save your model.\n","\n","**`visual_validation_after_training`**: If you select this option, a random image pair will be set aside from your training set and will be used to display a predicted image of the trained network next to the input and the ground-truth. This can aid in visually assessing the performance of your network after training. **Note: Your training set size will decrease by 1 if you select this option.**\n","\n"," **Select training parameters**\n","\n","**`number_of_epochs`**: Choose more epochs for larger training sets. Observing how much the loss reduces between epochs during training may help determine the optimal value. **Default: 200**\n","\n","**Advanced parameters - experienced users only**\n","\n","**`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. Smaller batchsizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset this can be due to a too large batch size. Decrease the number in this case. **Default: 4**\n","\n","**`number_of_steps`**: This number should be equivalent to the number of samples in the training set divided by the batch size, to ensure the training iterates through the entire training set. The default value is calculated to ensure this. This behaviour can also be obtained by setting it to 0. Other values can be used for testing.\n","\n"," **`pooling_steps`**: Choosing a different number of pooling layers can affect the performance of the network. Each additional pooling step will also two additional convolutions. The network can learn more complex information but is also more likely to overfit. Achieving best performance may require testing different values here. **Default: 2**\n","\n","**`percentage_validation`:** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0003**\n","\n","**`patch_width` and `patch_height`:** The notebook crops the data in patches of fixed size prior to training. The dimensions of the patches can be defined here. When `Use_Default_Advanced_Parameters` is selected, the largest 2^n x 2^n patch that fits in the smallest dataset is chosen. Larger patches than 512x512 should **NOT** be selected for network stability.\n","\n","**`min_fraction`:** Minimum fraction of pixels being foreground for a slected patch to be considered valid. 
It should be between 0 and 1.**Default value: 0.02** (2%)\n","\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["# ------------- Initial user input ------------\n","#@markdown ###Path to training images:\n","Training_source = '' #@param {type:\"string\"}\n","Training_target = '' #@param {type:\"string\"}\n","\n","model_name = '' #@param {type:\"string\"}\n","model_path = '' #@param {type:\"string\"}\n","\n","#@markdown ###Training parameters:\n","#@markdown Number of epochs\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced parameters:\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","batch_size = 4#@param {type:\"integer\"}\n","number_of_steps = 0#@param {type:\"number\"}\n","pooling_steps = 2 #@param [1,2,3,4]{type:\"raw\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","initial_learning_rate = 0.0003 #@param {type:\"number\"}\n","\n","patch_width = 512#@param{type:\"number\"}\n","patch_height = 512#@param{type:\"number\"}\n","min_fraction = 0.02#@param{type:\"number\"}\n","\n","\n","# ------------- Initialising folder, variables and failsafes ------------\n","# Create the folders where to save the model and the QC\n","full_model_path = os.path.join(model_path, model_name)\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Folder already exists and will be overwritten !!'+W)\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 4\n"," pooling_steps = 2\n"," percentage_validation = 10\n"," initial_learning_rate = 0.0003\n"," patch_width, patch_height = estimatePatchSize(Training_source)\n"," min_fraction = 0.02\n","\n","\n","#The create_patches function will create the two folders below\n","# Patch_source = '/content/img_patches'\n","# Patch_target = '/content/mask_patches'\n","print('Training on patches of size (x,y): ('+str(patch_width)+','+str(patch_height)+')')\n","\n","#Create patches\n","print('Creating patches...')\n","Patch_source, Patch_target = create_patches(Training_source, Training_target, patch_width, patch_height, min_fraction)\n","\n","number_of_training_dataset = len(os.listdir(Patch_source))\n","print('Total number of valid patches: '+str(number_of_training_dataset))\n","\n","if Use_Default_Advanced_Parameters or number_of_steps == 0:\n"," number_of_steps = ceil((100-percentage_validation)/100*number_of_training_dataset/batch_size)\n","print('Number of steps: '+str(number_of_steps))\n","\n","# Calculate the number of steps to use for validation\n","validation_steps = max(1, ceil(percentage_validation/100*number_of_training_dataset/batch_size))\n","\n","\n","# Here we disable pre-trained model by default (in case the next cell is not ran)\n","Use_pretrained_model = False\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","Use_Data_augmentation = False\n","# Build the default dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = 0.,\n"," height_shift_range = 0.,\n"," rotation_range = 0., #90\n"," zoom_range = 0.,\n"," shear_range = 0.,\n"," horizontal_flip = False,\n"," vertical_flip = False,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","# ------------- Display ------------\n","\n","#if not os.path.exists('/content/img_patches/'):\n","random_choice = random.choice(os.listdir(Patch_source))\n","x = io.imread(os.path.join(Patch_source, 
random_choice))\n","\n","#os.chdir(Training_target)\n","y = io.imread(os.path.join(Patch_target, random_choice), as_gray=True)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest',cmap='gray')\n","plt.title('Training image patch')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest',cmap='gray')\n","plt.title('Training mask patch')\n","plt.axis('off');\n","\n","plt.savefig('/content/TrainingDataExample_Unet2D.png',bbox_inches='tight',pad_inches=0)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset is large the values can be set to 0.\n","\n"," The augmentation options below are to be used as follows:\n","\n","* **shift**: a translation of the image by a fraction of the image size (width or height), **default: 10%**\n","* **zoom_range**: Increasing or decreasing the field of view. E.g. 10% will result in a zoom range of (0.9 to 1.1), with pixels added or interpolated, depending on the transformation, **default: 10%**\n","* **shear_range**: Shear angle in counter-clockwise direction, **default: 10%**\n","* **flip**: creating a mirror image along specified axis (horizontal or vertical), **default: True**\n","* **rotation_range**: range of allowed rotation angles in degrees (from 0 to *value*), **default: 180**"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#@markdown ##**Augmentation options**\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," if Use_Default_Augmentation_Parameters:\n"," horizontal_shift = 10 \n"," vertical_shift = 20 \n"," zoom_range = 10\n"," shear_range = 10\n"," horizontal_flip = True\n"," vertical_flip = True\n"," rotation_range = 180\n","#@markdown ###If you are not using the default settings, please provide the values below:\n","\n","#@markdown ###**Image shift, zoom, shear and flip (%)**\n"," else:\n"," horizontal_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," vertical_shift = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," zoom_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," shear_range = 10 #@param {type:\"slider\", min:0, max:100, step:1}\n"," horizontal_flip = True #@param {type:\"boolean\"}\n"," vertical_flip = True #@param {type:\"boolean\"}\n","\n","#@markdown ###**Rotate image within angle range (degrees):**\n"," rotation_range = 180 #@param {type:\"slider\", min:0, max:180, step:1}\n","\n","#given behind the # are the default values for each parameter.\n","\n","else:\n"," horizontal_shift = 0 \n"," vertical_shift = 0 \n"," zoom_range = 0\n"," shear_range = 0\n"," horizontal_flip = False\n"," vertical_flip = False\n"," rotation_range = 0\n","\n","\n","# Build the dict for the ImageDataGenerator\n","data_gen_args = dict(width_shift_range = horizontal_shift/100.,\n"," height_shift_range = vertical_shift/100.,\n"," rotation_range = rotation_range, #90\n"," zoom_range = zoom_range/100.,\n"," shear_range = 
shear_range/100.,\n"," horizontal_flip = horizontal_flip,\n"," vertical_flip = vertical_flip,\n"," validation_split = percentage_validation/100,\n"," fill_mode = 'reflect')\n","\n","\n","\n","# ------------- Display ------------\n","dir_augmented_data_imgs=\"/content/augment_img\"\n","dir_augmented_data_masks=\"/content/augment_mask\"\n","random_choice = random.choice(os.listdir(Patch_source))\n","orig_img = load_img(os.path.join(Patch_source,random_choice))\n","orig_mask = load_img(os.path.join(Patch_target,random_choice))\n","\n","augment_view = ImageDataGenerator(**data_gen_args)\n","\n","if Use_Data_augmentation:\n"," print(\"Parameters enabled\")\n"," print(\"Here is what a subset of your augmentations looks like:\")\n"," save_augment(augment_view, orig_img, dir_augmented_data=dir_augmented_data_imgs)\n"," save_augment(augment_view, orig_mask, dir_augmented_data=dir_augmented_data_masks)\n","\n"," fig = plt.figure(figsize=(15, 7))\n"," fig.subplots_adjust(hspace=0.0,wspace=0.1,left=0,right=1.1,bottom=0, top=0.8)\n","\n"," \n"," ax = fig.add_subplot(2, 6, 1,xticks=[],yticks=[]) \n"," new_img=img_as_ubyte(normalizeMinMax(img_to_array(orig_img)))\n"," ax.imshow(new_img)\n"," ax.set_title('Original Image')\n"," i = 2\n"," for imgnm in os.listdir(dir_augmented_data_imgs):\n"," ax = fig.add_subplot(2, 6, i,xticks=[],yticks=[]) \n"," img = load_img(dir_augmented_data_imgs + \"/\" + imgnm)\n"," ax.imshow(img)\n"," i += 1\n","\n"," ax = fig.add_subplot(2, 6, 7,xticks=[],yticks=[]) \n"," new_mask=img_as_ubyte(normalizeMinMax(img_to_array(orig_mask)))\n"," ax.imshow(new_mask)\n"," ax.set_title('Original Mask')\n"," j=2\n"," for imgnm in os.listdir(dir_augmented_data_masks):\n"," ax = fig.add_subplot(2, 6, j+6,xticks=[],yticks=[]) \n"," mask = load_img(dir_augmented_data_masks + \"/\" + imgnm)\n"," ax.imshow(mask)\n"," j += 1\n"," plt.show()\n","\n","else:\n"," print(\"No augmentation will be used\")\n","\n"," "],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a U-Net model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
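Before pointing the cell below at a pre-trained model, it can help to confirm that the folder actually contains what the loading code expects. A minimal sketch, assuming the ZeroCostDL4Mic folder layout used elsewhere in this notebook (weight files plus the Quality Control CSV from which the learning rate is read):

```python
import os

def check_pretrained_folder(folder):
    # Files the loading cell below looks for; the CSV is optional but allows
    # the last/best learning rate to be restored.
    expected = ['weights_last.hdf5', 'weights_best.hdf5',
                os.path.join('Quality Control', 'training_evaluation.csv')]
    for rel_path in expected:
        status = 'found' if os.path.exists(os.path.join(folder, rel_path)) else 'MISSING'
        print(rel_path + ': ' + status)
```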
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","pretrained_model_choice = \"Model_from_file\" #@param [\"Model_from_file\"]\n","Weights_choice = \"last\" #@param [\"last\", \"best\"]\n","\n","\n","#@markdown ###If you chose \"Model_from_file\", please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n","# --------------------- Load the model from the choosen path ------------------------\n"," if pretrained_model_choice == \"Model_from_file\":\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","\n","# --------------------- Download the a model provided in the XXX ------------------------\n","\n"," if pretrained_model_choice == \"Model_name\":\n"," pretrained_model_name = \"Model_name\"\n"," pretrained_model_path = \"/content/\"+pretrained_model_name\n"," print(\"Downloading the UNET_Model_from_\")\n"," if os.path.exists(pretrained_model_path):\n"," shutil.rmtree(pretrained_model_path)\n"," os.makedirs(pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path)\n"," wget.download(\"\", pretrained_model_path) \n"," wget.download(\"\", pretrained_model_path)\n"," h5_file_path = os.path.join(pretrained_model_path, \"weights_\"+Weights_choice+\".hdf5\")\n","\n","# --------------------- Add additional pre-trained models here ------------------------\n","\n","\n","\n","# --------------------- Check the model exist ------------------------\n","# If the model path chosen does not contain a pretrain model then use_pretrained_model is disabled, \n"," if not os.path.exists(h5_file_path):\n"," print(R+'WARNING: pretrained model does not exist')\n"," Use_pretrained_model = False\n"," \n","\n","# If the model path contains a pretrain model, we load the training rate, \n"," if os.path.exists(h5_file_path):\n","#Here we check if the learning rate can be loaded from the quality control folder\n"," if os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n","\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," #print(csvRead)\n"," \n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n"," 
print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","#Compatibility with models trained outside ZeroCostDL4Mic but default learning rate will be used\n"," if not os.path.exists(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv')):\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(initial_learning_rate)+' will be used instead'+ W)\n"," bestLearningRate = initial_learning_rate\n"," lastLearningRate = initial_learning_rate\n","\n","\n","# Display info about the pretrained model to be loaded (or not)\n","if Use_pretrained_model:\n"," print('Weights found in:')\n"," print(h5_file_path)\n"," print('will be loaded prior to training.')\n","\n","else:\n"," print(R+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---\n","####**Troubleshooting:** If you receive a time-out or exhausted error, try reducing the batchsize of your training set. This reduces the amount of data loaded into the model at one point in time. "]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data and model for training**\n","---\n","Here, we use the information from 3. to build the model and convert the training data into a suitable format for training."]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Play this cell to prepare the model for training\n","\n","\n","# ------------------ Set the generators, model and logger ------------------\n","# This will take the image size and set that as a patch size (arguable...)\n","# Read image size (without actuall reading the data)\n","\n","(train_datagen, validation_datagen) = prepareGenerators(Patch_source, Patch_target, data_gen_args, batch_size, target_size = (patch_width, patch_height))\n","\n","\n","# This modelcheckpoint will only save the best model from the validation loss point of view\n","model_checkpoint = ModelCheckpoint(os.path.join(full_model_path, 'weights_best.hdf5'), monitor='val_loss',verbose=1, save_best_only=True)\n","\n","print('Getting class weights...')\n","class_weights = getClassWeights(Training_target)\n","\n","# --------------------- Using pretrained model ------------------------\n","#Here we ensure that the learning rate set correctly when using pre-trained models\n","if Use_pretrained_model:\n"," if Weights_choice == \"last\":\n"," initial_learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\": \n"," initial_learning_rate = bestLearningRate\n","else:\n"," h5_file_path = None\n","\n","# --------------------- ---------------------- ------------------------\n","\n","# --------------------- Reduce learning rate on plateau ------------------------\n","\n","reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, verbose=1, mode='auto',\n"," patience=10, min_lr=0)\n","# --------------------- ---------------------- ------------------------\n","\n","\n","# Define the model\n","model = unet(pretrained_weights = h5_file_path, \n"," input_size = (patch_width,patch_height,1), \n"," pooling_steps = pooling_steps, \n"," learning_rate = initial_learning_rate, \n"," class_weights = class_weights)\n","\n","config_model= 
model.optimizer.get_config()\n","print(config_model)\n","\n","\n","# ------------------ Failsafes ------------------\n","if os.path.exists(full_model_path):\n"," print(R+'!! WARNING: Model folder already existed and has been removed !!'+W)\n"," shutil.rmtree(full_model_path)\n","\n","os.makedirs(full_model_path)\n","os.makedirs(os.path.join(full_model_path,'Quality Control'))\n","\n","\n","# ------------------ Display ------------------\n","print('---------------------------- Main training parameters ----------------------------')\n","print('Number of epochs: '+str(number_of_epochs))\n","print('Batch size: '+str(batch_size))\n","print('Number of training dataset: '+str(number_of_training_dataset))\n","print('Number of training steps: '+str(number_of_steps))\n","print('Number of validation steps: '+str(validation_steps))\n","print('---------------------------- ------------------------ ----------------------------')\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder from Google Drive as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","# history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs=epochs, callbacks=[model_checkpoint,csv_log], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","history = model.fit_generator(train_datagen, steps_per_epoch = number_of_steps, epochs = number_of_epochs, callbacks=[model_checkpoint, reduce_lr], validation_data = validation_datagen, validation_steps = validation_steps, shuffle=True, verbose=1)\n","\n","# Save the last model\n","model.save(os.path.join(full_model_path, 'weights_last.hdf5'))\n","\n","\n","# convert the history.history dict to a pandas DataFrame: \n","lossData = pd.DataFrame(history.history) \n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
\n","lossDataCSVpath = os.path.join(full_model_path,'Quality Control/training_evaluation.csv')\n","with open(lossDataCSVpath, 'w') as f:\n"," writer = csv.writer(f)\n"," writer.writerow(['loss','val_loss', 'learning rate'])\n"," for i in range(len(history.history['loss'])):\n"," writer.writerow([history.history['loss'][i], history.history['val_loss'][i], history.history['lr'][i]])\n"," \n","\n","\n","# Displaying the time elapsed for training\n","print(\"------------------------------------------\")\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\", hour, \"hour(s)\", mins,\"min(s)\",round(sec),\"sec(s)\")\n","print(\"------------------------------------------\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","\n","full_QC_model_path = os.path.join(QC_model_path, QC_model_name)\n","if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. 
Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","\n","epochNumber = []\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","with open(os.path.join(full_QC_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(os.path.join(full_QC_model_path, 'Quality Control', 'lossCurvePlots.png'),bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will calculate the Intersection over Union score for all the images provided in the Source_QC_folder and Target_QC_folder. The result for one of the image will also be displayed.\n","\n","The **Intersection over Union** metric is a method that can be used to quantify the percent overlap between the target mask and your prediction output. **Therefore, the closer to 1, the better the performance.** This metric can be used to assess the quality of your model to accurately predict nuclei. \n","\n","The Input, Ground Truth, Prediction and IoU maps are shown below for the last example in the QC set.\n","\n"," The results for all QC examples can be found in the \"*Quality Control*\" folder which is located inside your \"model_folder\".\n","\n","### **Thresholds for image masks**\n","\n"," Since the output from Unet is not a binary mask, the output images are converted to binary masks using thresholding. This section will test different thresholds (from 0 to 255) to find the one yielding the best IoU score compared with the ground truth. 
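For reference, the metric itself is simple to compute. A minimal sketch, assuming a greyscale prediction in the 0-255 range and a binary (or labelled) ground-truth mask:

```python
import numpy as np

def intersection_over_union(prediction, ground_truth, threshold=128):
    # Binarise the prediction at the chosen grey value, then compare with the mask.
    pred_mask = prediction > threshold
    gt_mask = ground_truth > 0
    union = np.logical_or(pred_mask, gt_mask).sum()
    if union == 0:
        return np.nan  # both images empty: IoU undefined
    return np.logical_and(pred_mask, gt_mask).sum() / union
```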
The best threshold for each image and the average of these thresholds will be displayed below. **These values can be a guideline when creating masks for unseen data in section 6.**"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["# ------------- User input ------------\n","#@markdown ##Choose the folders that contain your Quality Control dataset\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","\n","# ------------- Initialise folders ------------\n","# Create a quality control/Prediction Folder\n","prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control', 'Prediction')\n","if os.path.exists(prediction_QC_folder):\n"," shutil.rmtree(prediction_QC_folder)\n","\n","os.makedirs(prediction_QC_folder)\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model\n","unet = load_model(os.path.join(full_QC_model_path, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Source_QC_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Source_QC_folder, source_dir_list[i]), unet))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(prediction_QC_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","#-----------------------------Calculate Metrics----------------------------------------#\n","\n","f = plt.figure(figsize=((5,5)))\n","\n","with open(os.path.join(full_QC_model_path,'Quality Control', 'QC_metrics_'+QC_model_name+'.csv'), \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"File name\",\"IoU\", \"IoU-optimised threshold\"]) \n","\n"," # Initialise the lists \n"," filename_list = []\n"," best_threshold_list = []\n"," best_IoU_score_list = []\n","\n"," for filename in os.listdir(Source_QC_folder):\n","\n"," if not os.path.isdir(os.path.join(Source_QC_folder, filename)):\n"," print('Running QC on: '+filename)\n"," test_input = io.imread(os.path.join(Source_QC_folder, filename), as_gray=True)\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, filename), as_gray=True)\n","\n"," (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(os.path.join(prediction_QC_folder, prediction_prefix+filename), os.path.join(Target_QC_folder, filename))\n"," plt.plot(threshold_list,iou_scores_per_threshold, label=filename)\n","\n"," # Here we find which threshold yielded the highest IoU score for image n.\n"," best_IoU_score = max(iou_scores_per_threshold)\n"," best_threshold = iou_scores_per_threshold.index(best_IoU_score)\n","\n"," # Write the results in the CSV file\n"," writer.writerow([filename, str(best_IoU_score), str(best_threshold)])\n","\n"," # Here we append the best threshold and score to the lists\n"," filename_list.append(filename)\n"," best_IoU_score_list.append(best_IoU_score)\n"," best_threshold_list.append(best_threshold)\n","\n","# Display the IoV vs Threshold plot\n","plt.title('IoU vs. 
Threshold')\n","plt.ylabel('IoU')\n","plt.xlabel('Threshold value')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'/Quality Control/'+QC_model_name+'_IoUvsThresholdPlot.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","# Table with metrics as dataframe output\n","pdResults = pd.DataFrame(index = filename_list)\n","pdResults[\"IoU\"] = best_IoU_score_list\n","pdResults[\"IoU-optimised threshold\"] = best_threshold_list\n","\n","average_best_threshold = sum(best_threshold_list)/len(best_threshold_list)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file=os.listdir(Source_QC_folder)):\n"," \n"," plt.figure(figsize=(25,5))\n"," #Input\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(plt.imread(os.path.join(Source_QC_folder, file)), aspect='equal', cmap='gray', interpolation='nearest')\n"," plt.title('Input')\n","\n"," #Ground-truth\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," test_ground_truth_image = io.imread(os.path.join(Target_QC_folder, file),as_gray=True)\n"," plt.imshow(test_ground_truth_image, aspect='equal', cmap='Greens')\n"," plt.title('Ground Truth')\n","\n"," #Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," test_prediction = plt.imread(os.path.join(prediction_QC_folder, prediction_prefix+file))\n"," test_prediction_mask = np.empty_like(test_prediction)\n"," test_prediction_mask[test_prediction > average_best_threshold] = 255\n"," test_prediction_mask[test_prediction <= average_best_threshold] = 0\n"," plt.imshow(test_prediction_mask, aspect='equal', cmap='Purples')\n"," plt.title('Prediction')\n","\n"," #Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_ground_truth_image, cmap='Greens')\n"," plt.imshow(test_prediction_mask, alpha=0.5, cmap='Purples')\n"," metrics_title = 'Overlay (IoU: ' + str(round(pdResults.loc[file][\"IoU\"],3)) + ' T: ' + str(round(pdResults.loc[file][\"IoU-optimised threshold\"])) + ')'\n"," plt.title(metrics_title)\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","\n","print('--------------------------------------------------------------')\n","print('Best average threshold is: '+str(round(average_best_threshold)))\n","print('--------------------------------------------------------------')\n","\n","pdResults.head()\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1 Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.1) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n"," Once the predictions are complete the cell will display a random example prediction beside the input image and the calculated mask for visual inspection.\n","\n"," **Troubleshooting:** If there is a low contrast image warning when saving the images, this may be due to overfitting of the model to the data. It may result in images containing only a single colour. Train the network again with different network hyperparameters."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["\n","\n","# ------------- Initial user input ------------\n","#@markdown ###Provide the path to your dataset and to the folder where the prediction will be saved (Result folder), then play the cell to predict output on your unseen images.\n","Data_folder = '' #@param {type:\"string\"}\n","Results_folder = '' #@param {type:\"string\"}\n","\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","\n","# ------------- Failsafes ------------\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_name = model_name\n"," Prediction_model_path = model_path\n","\n","full_Prediction_model_path = os.path.join(Prediction_model_path, Prediction_model_name)\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","\n","# ------------- Prepare the model and run predictions ------------\n","\n","# Load the model and prepare generator\n","\n","\n","\n","unet = load_model(os.path.join(Prediction_model_path, Prediction_model_name, 'weights_best.hdf5'), custom_objects={'_weighted_binary_crossentropy': weighted_binary_crossentropy(np.ones(2))})\n","Input_size = unet.layers[0].output_shape[1:3]\n","print('Model input size: '+str(Input_size[0])+'x'+str(Input_size[1]))\n","\n","# Create a list of sources\n","source_dir_list = os.listdir(Data_folder)\n","number_of_dataset = len(source_dir_list)\n","print('Number of dataset found in the folder: '+str(number_of_dataset))\n","\n","predictions = []\n","for i in tqdm(range(number_of_dataset)):\n"," predictions.append(predict_as_tiles(os.path.join(Data_folder, source_dir_list[i]), unet))\n"," # predictions.append(prediction(os.path.join(Data_folder, source_dir_list[i]), os.path.join(Prediction_model_path, Prediction_model_name)))\n","\n","\n","# Save the results in the folder along with the masks according to the set threshold\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=None)\n","\n","\n","# ------------- For display ------------\n","print('--------------------------------------------------------------')\n","\n","\n","def show_prediction_mask(file=os.listdir(Data_folder), threshold=(0,255,1)):\n","\n"," plt.figure(figsize=(18,6))\n"," # Wide-field\n"," plt.subplot(1,3,1)\n"," plt.axis('off')\n"," img_Source = plt.imread(os.path.join(Data_folder, file))\n"," plt.imshow(img_Source, cmap='gray')\n"," plt.title('Source image',fontsize=15)\n"," # Prediction\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," img_Prediction = plt.imread(os.path.join(Results_folder, prediction_prefix+file))\n"," plt.imshow(img_Prediction, cmap='gray')\n"," plt.title('Prediction',fontsize=15)\n","\n"," # Thresholded mask\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," img_Mask = convert2Mask(img_Prediction, threshold)\n"," plt.imshow(img_Mask, cmap='gray')\n"," plt.title('Mask (Threshold: '+str(round(threshold))+')',fontsize=15)\n","\n","\n","interact(show_prediction_mask, continuous_update=False);\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"stS96mFZLMOU"},"source":["## **6.2. Export results as masks**\n","---\n"]},{"cell_type":"code","metadata":{"id":"qb5ZmFstLNbR","cellView":"form"},"source":["\n","# @markdown #Play this cell to save results as masks with the chosen threshold\n","threshold = 120#@param {type:\"number\"}\n","\n","saveResult(Results_folder, predictions, source_dir_list, prefix=prediction_prefix, threshold=threshold)\n","print('-------------------')\n","print('Masks were saved in: '+Results_folder)\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"BphZ0wBrC2Zw"},"source":["# **7. 
Version log**\n","\n","---\n","**v1.13**: \n","* This version now includes an automatic restart allowing to set the h5py library to v2.10. \n","* The section 1 and 2 are now swapped for better export of *requirements.txt*. \n","* This version also now includes built-in version check and the version log that you're reading now.\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using 2D U-Net!**\n"]}]} \ No newline at end of file diff --git a/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb index f0263246..0c5b0e02 100644 --- a/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/U-Net_3D_ZeroCostDL4Mic.ipynb @@ -1,2403 +1 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "IkSguVy8Xv83" - }, - "source": [ - "# **U-Net (3D)**\n", - " ---\n", - "\n", - " The 3D U-Net was first introduced by [Çiçek et al](https://arxiv.org/abs/1606.06650) for learning dense volumetric segmentations from sparsely annotated ground-truth data building upon the original U-Net architecture by [Ronneberger et al](https://arxiv.org/abs/1505.04597). \n", - "\n", - "**This particular implementation allows supervised learning between any two types of 3D image data. If you are interested in image segmentation of 2D datasets, you should use the 2D U-Net notebook instead.**\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project ([ZeroCostDL4Mic](/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki)) jointly developed by the [Jacquemet](https://cellmig.org/) and [Henriques](https://henriqueslab.github.io/) laboratories and created by Daniel Krentzel.\n", - "\n", - "This notebook is laregly based on the following paper: \n", - "\n", - "[**3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation**](https://arxiv.org/pdf/1606.06650.pdf) by Özgün Çiçek *et al.* published on arXiv in 2016\n", - "\n", - "The following two Python libraries play an important role in the notebook: \n", - "\n", - "1. [**Elasticdeform**](/~https://github.com/gvtulder/elasticdeform)\n", - " by Gijs van Tulder was used to augment the 3D training data using elastic grid-based deformations as described in the original 3D U-Net paper. \n", - "\n", - "2. [**Tifffile**](/~https://github.com/cgohlke/tifffile) by Christoph Gohlke is a great library for reading and writing TIFF files. \n", - "\n", - "3. [**Imgaug**](/~https://github.com/aleju/imgaug) by Alexander Jung *et al.* is an amazing library for image augmentation in machine learning - it is the most complete and extensive image augmentation package I have found to date. 
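To give an idea of how imgaug composes augmenters (this mirrors the `iaa.Sequential` pattern used later in this notebook, but the specific augmenters here are only an example, not the notebook's exact pipeline):

```python
import imgaug.augmenters as iaa

augmenter = iaa.Sequential([
    iaa.Fliplr(0.5),               # horizontal flip with 50% probability
    iaa.Flipud(0.5),               # vertical flip with 50% probability
    iaa.Affine(rotate=(-45, 45)),  # random rotation in degrees
], random_order=True)              # apply the augmenters in random order

# A batch of images (N, H, W) or (N, H, W, C) could then be augmented with:
# augmented = augmenter(images=batch)
```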
\n", - "\n", - "The [example dataset](https://www.epfl.ch/labs/cvlab/data/data-em/) represents a 5x5x5µm section taken from the CA1 hippocampus region of the brain with annotated mitochondria and was acquired by Graham Knott and Marco Cantoni at EPFL.\n", - "\n", - "\n", - "**Please also cite the original paper and relevant Python libraries when using or developing this notebook.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use ZeroCostDL4Mic notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cells: \n", - "\n", - "**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new one by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code which can be modfied by selecting the cell. To execute the cell, move your cursor to the `[]`-symbol on the left side of the cell (a play button should appear). Click it to execute the cell. Once the cell is fully executed, the animation stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "Three tabs are located on the upper left side of the notebook:\n", - "\n", - "1. *Table of contents* contains the structure of the notebook. Click the headers to move quickly between sections.\n", - "\n", - "2. *Code snippets* provides a wide array of example code specific to Google Colab. You can ignore this when using this notebook.\n", - "\n", - "3. *Files* displays the current working directory. We will mount your Google Drive in Section 1.2. so that you can access your files and save them permanently.\n", - "\n", - "**Important:** All uploaded files are purged once the runtime ends.\n", - "\n", - "**Note:** The directory *sample data* in *Files* contains default files. Do not upload anything there!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive by clicking *File* -> *Save a copy in Drive*.\n", - "\n", - "To **edit a cell**, double click on the text. This will either display the source code (in code cells) or the [markdown](https://colab.research.google.com/notebooks/markdown_guide.ipynb#scrollTo=70pYkR9LiOV0) (in text cells).\n", - "You can use `#` in code cells to comment out parts of the code. This allows you to keep the original piece of code while not executing it." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gKDLkLWUd-YX" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - "\n", - "As the network operates in three dimensions, certain consideration should be given to correctly pre-processing the data. 
Ensure that the structure of interest does not substantially change between slices - image volumes with isotropic pixelsizes are ideal for this architecture.\n", - "\n", - "Each image volume must be provided as an **8-bit** or **binary multipage TIFF file** to maintain the correct ordering of individual image slices. If more than one image volume has been annotated, source and target files must be named identically and placed in separate directories. In case only one image volume has been annotated, source and target file do not have to be placed in separate directories and can be named differently, as long as their paths are explicitly provided in Section 3. \n", - "\n", - "**Prepare two datasets** (*training* and *testing*) for quality control puproses. Make sure that the *testing* dataset does not overlap with the *training* dataset and is ideally sourced from a different acquisiton and sample to ensure robustness of the trained model. \n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "### **Directory structure**\n", - "\n", - "Make sure to adhere to one of the following directory structures. If only one annotated training volume exists, choose the first structure. In case more than one training volume is available, choose the second structure.\n", - "\n", - "**Structure 1:** Only one training volume\n", - "```\n", - "path/to/directory/with/one/training/volume\n", - "│--training_source.tif\n", - "│--training_target.tif\n", - "| \n", - "│--testing_source.tif\n", - "|--testing_target.tif \n", - "|\n", - "|--data_to_predict_on.tif\n", - "|--prediction_results.tif\n", - "\n", - "```\n", - "**Structure 2:** Various training volumes\n", - "```\n", - "path/to/directory/with/various/training/volumes\n", - "│--testing_source.tif\n", - "|--testing_target.tif \n", - "|\n", - "└───training\n", - "| └───source\n", - "| | |--training_volume_one.tif\n", - "| | |--training_volume_two.tif\n", - "| | |--...\n", - "| | |--training_volume_n.tif\n", - "| |\n", - "| └───target\n", - "| |--training_volume_one.tif\n", - "| |--training_volume_two.tif\n", - "| |--...\n", - "| |--training_volume_n.tif\n", - "|\n", - "|--data_to_predict_on.tif\n", - "|--prediction_results.tif\n", - "```\n", - "**Note:** Naming directories is completely up to you, as long as the paths are correctly specified throughout the notebook.\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "### **Important note**\n", - "\n", - "* If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do so), you will need to run **Sections 1 - 4**, then use **Section 5** to assess the quality of your model and **Section 6** to run predictions using the model that you trained.\n", - "\n", - "* If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 5** to assess the quality of your model.\n", - "\n", - "* If you only wish to **Run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "M-GZMaL7pd8a" - }, - "outputs": [], - "source": [ - "#@markdown ##**Download example dataset**\n", - "\n", - "#@markdown This usually takes a few minutes. 
The images are saved in *example_dataset*.\n", - "\n", - "import requests \n", - "import os\n", - "from tqdm.notebook import tqdm \n", - "\n", - "def make_directory(dir):\n", - " if not os.path.exists(dir):\n", - " os.makedirs(dir)\n", - "\n", - "def download_from_url(url, save_as):\n", - " file_url = url\n", - " r = requests.get(file_url, stream=True) \n", - " \n", - " with open(save_as, 'wb') as file: \n", - " for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=126875, ncols=1000):\n", - " if block:\n", - " file.write(block) \n", - "\n", - "\n", - "make_directory('example_dataset')\n", - "\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training.tif', 'example_dataset/training.tif')\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training_groundtruth.tif', 'example_dataset/training_groundtruth.tif')\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing.tif', 'example_dataset/testing.tif')\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing_groundtruth.tif', 'example_dataset/testing_groundtruth.tif')\n", - "\n", - "print('Example dataset successfully downloaded!')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "\n", - "\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "zCvebubeSaGY" - }, - "outputs": [], - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "%tensorflow_version 1.x\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sNIVx8_CLolt" - }, - "source": [ - "## **1.2. Mount Google Drive**\n", - "---\n", - " To use this notebook with your **own data**, place it in a folder on **Google Drive** following one of the directory structures outlined in **Section 0**.\n", - "\n", - "1. **Run** the **cell** below to mount your Google Drive and follow the link. \n", - "\n", - "2. **Sign in** to your Google account and press 'Allow'. \n", - "\n", - "3. Next, copy the **authorization code**, paste it into the cell and press enter. This will allow Colab to read and write data from and to your Google Drive. 
\n", - "\n", - "4. Once this is done, your data can be viewed in the **Files tab** on the top left of the notebook after hitting 'Refresh'." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "01Djr8v-5pPk" - }, - "outputs": [], - "source": [ - "#@markdown ##Play the cell to connect your Google Drive to Colab\n", - "\n", - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "zxELU7CIp4oF" - }, - "outputs": [], - "source": [ - "#@markdown ##Unzip pre-trained model directory\n", - "\n", - "#@markdown 1. Upload a zipped model directory using the *Files* tab\n", - "#@markdown 2. Run this cell to unzip your model file\n", - "#@markdown 3. The model directory will appear in the *Files* tab \n", - "\n", - "from google.colab import files\n", - "\n", - "zipped_model_file = \"\" #@param {type:\"string\"}\n", - "\n", - "!unzip \"$zipped_model_file\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AdN8B91xZO0x" - }, - "source": [ - "# **2. Install 3D U-Net dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "fq21zJVFNASx" - }, - "outputs": [], - "source": [ - "#@markdown ##Install dependencies and instantiate network\n", - "Notebook_version = ['1.12']\n", - "#Put the imported code and libraries here\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "!pip install fpdf\n", - "from __future__ import absolute_import, division, print_function, unicode_literals\n", - "\n", - "try:\n", - " import elasticdeform\n", - "except:\n", - " !pip install elasticdeform\n", - " import elasticdeform\n", - "\n", - "try:\n", - " import tifffile\n", - "except:\n", - " !pip install tifffile\n", - " import tifffile\n", - "\n", - "try:\n", - " import imgaug.augmenters as iaa\n", - "except:\n", - " !pip install imgaug\n", - " import imgaug.augmenters as iaa\n", - "\n", - "import os\n", - "import csv\n", - "import random\n", - "import h5py\n", - "import imageio\n", - "import math\n", - "import shutil\n", - "\n", - "import pandas as pd\n", - "from glob import glob\n", - "from tqdm import tqdm\n", - "\n", - "from skimage import transform\n", - "from skimage import exposure\n", - "from skimage import color\n", - "from skimage import io\n", - "\n", - "from scipy.ndimage import zoom\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import numpy as np\n", - "import tensorflow as tf\n", - "\n", - "from keras import backend as K\n", - "\n", - "from keras.layers import Conv3D\n", - "from keras.layers import BatchNormalization\n", - "from keras.layers import ReLU\n", - "from keras.layers import MaxPooling3D\n", - "from keras.layers import Conv3DTranspose\n", - "from keras.layers import Input\n", - "from keras.layers import Concatenate\n", - "\n", - "from keras.models import Model\n", - "\n", - "from keras.utils import Sequence\n", - "\n", - "from keras.callbacks import ModelCheckpoint\n", - "from keras.callbacks import CSVLogger\n", - "from keras.callbacks import Callback\n", - "\n", - "from keras.metrics import RootMeanSquaredError\n", - "\n", - "from ipywidgets import interact\n", - "from ipywidgets import interactive\n", - "from ipywidgets import fixed\n", - "from ipywidgets import interact_manual \n", - "import ipywidgets as widgets\n", - "\n", - 
"from fpdf import FPDF, HTMLMixin\n", - "from datetime import datetime\n", - "import subprocess\n", - "from pip._internal.operations.freeze import freeze\n", - "import time\n", - "\n", - "from skimage import io\n", - "import matplotlib\n", - "\n", - "print(\"Dependencies installed and imported.\")\n", - "\n", - "# Define MultiPageTiffGenerator class\n", - "class MultiPageTiffGenerator(Sequence):\n", - "\n", - " def __init__(self,\n", - " source_path,\n", - " target_path,\n", - " batch_size=1,\n", - " shape=(128,128,32,1),\n", - " augment=False,\n", - " augmentations=[],\n", - " deform_augment=False,\n", - " deform_augmentation_params=(5,3,4),\n", - " val_split=0.2,\n", - " is_val=False,\n", - " random_crop=True,\n", - " downscale=1,\n", - " binary_target=False):\n", - "\n", - " # If directory with various multi-page tiffiles is provided read as list\n", - " if os.path.isfile(source_path):\n", - " self.dir_flag = False\n", - " self.source = tifffile.imread(source_path)\n", - " if binary_target:\n", - " self.target = tifffile.imread(target_path).astype(np.bool)\n", - " else:\n", - " self.target = tifffile.imread(target_path)\n", - "\n", - " elif os.path.isdir(source_path):\n", - " self.dir_flag = True\n", - " self.source_dir_list = glob(os.path.join(source_path, '*'))\n", - " self.target_dir_list = glob(os.path.join(target_path, '*'))\n", - "\n", - " self.source_dir_list.sort()\n", - " self.target_dir_list.sort()\n", - "\n", - " self.shape = shape\n", - " self.batch_size = batch_size\n", - " self.augment = augment\n", - " self.val_split = val_split\n", - " self.is_val = is_val\n", - " self.random_crop = random_crop\n", - " self.downscale = downscale\n", - " self.binary_target = binary_target\n", - " self.deform_augment = deform_augment\n", - " self.on_epoch_end()\n", - " \n", - " if self.augment:\n", - " # pass list of augmentation functions \n", - " self.seq = iaa.Sequential(augmentations, random_order=True) # apply augmenters in random order\n", - " if self.deform_augment:\n", - " self.deform_sigma, self.deform_points, self.deform_order = deform_augmentation_params\n", - "\n", - " def __len__(self):\n", - " # If various multi-page tiff files provided sum all images within each\n", - " if self.augment:\n", - " augment_factor = 4\n", - " else:\n", - " augment_factor = 1\n", - " \n", - " if self.dir_flag:\n", - " num_of_imgs = 0\n", - " for tiff_path in self.source_dir_list:\n", - " num_of_imgs += tifffile.imread(tiff_path).shape[0]\n", - " xy_shape = tifffile.imread(self.source_dir_list[0]).shape[1:]\n", - "\n", - " if self.is_val:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = xy_shape[0] * xy_shape[1] * self.val_split * num_of_imgs\n", - " return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - " else:\n", - " return math.floor(self.val_split * num_of_imgs / self.batch_size)\n", - " else:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = xy_shape[0] * xy_shape[1] * (1 - self.val_split) * num_of_imgs\n", - " return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - "\n", - " else:\n", - " return math.floor(augment_factor*(1 - self.val_split) * num_of_imgs/self.batch_size)\n", - " else:\n", - " if self.is_val:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = self.source.shape[0] * self.source.shape[1] * 
self.val_split * self.source.shape[2]\n", - " return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - " else:\n", - " return math.floor((self.val_split * self.source.shape[0] / self.batch_size))\n", - " else:\n", - " if self.random_crop:\n", - " crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n", - " volume = self.source.shape[0] * self.source.shape[1] * (1 - self.val_split) * self.source.shape[2]\n", - " return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n", - " else:\n", - " return math.floor(augment_factor * (1 - self.val_split) * self.source.shape[0] / self.batch_size)\n", - "\n", - " def __getitem__(self, idx):\n", - " source_batch = np.empty((self.batch_size,\n", - " self.shape[0],\n", - " self.shape[1],\n", - " self.shape[2],\n", - " self.shape[3]))\n", - " target_batch = np.empty((self.batch_size,\n", - " self.shape[0],\n", - " self.shape[1],\n", - " self.shape[2],\n", - " self.shape[3]))\n", - "\n", - " for batch in range(self.batch_size):\n", - " # Modulo operator ensures IndexError is avoided\n", - " stack_start = self.batch_list[(idx+batch*self.shape[2])%len(self.batch_list)]\n", - "\n", - " if self.dir_flag:\n", - " self.source = tifffile.imread(self.source_dir_list[stack_start[0]])\n", - " if self.binary_target:\n", - " self.target = tifffile.imread(self.target_dir_list[stack_start[0]]).astype(np.bool)\n", - " else:\n", - " self.target = tifffile.imread(self.target_dir_list[stack_start[0]])\n", - "\n", - " src_list = []\n", - " tgt_list = []\n", - " for i in range(stack_start[1], stack_start[1]+self.shape[2]):\n", - " src = self.source[i]\n", - " src = transform.downscale_local_mean(src, (self.downscale, self.downscale))\n", - " if not self.random_crop:\n", - " src = transform.resize(src, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n", - " src = self._min_max_scaling(src)\n", - " src_list.append(src)\n", - "\n", - " tgt = self.target[i]\n", - " tgt = transform.downscale_local_mean(tgt, (self.downscale, self.downscale))\n", - " if not self.random_crop:\n", - " tgt = transform.resize(tgt, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n", - " if not self.binary_target:\n", - " tgt = self._min_max_scaling(tgt)\n", - " tgt_list.append(tgt)\n", - "\n", - " if self.random_crop:\n", - " if src.shape[0] == self.shape[0]:\n", - " x_rand = 0\n", - " if src.shape[1] == self.shape[1]:\n", - " y_rand = 0\n", - " if src.shape[0] > self.shape[0]:\n", - " x_rand = np.random.randint(src.shape[0] - self.shape[0])\n", - " if src.shape[1] > self.shape[1]:\n", - " y_rand = np.random.randint(src.shape[1] - self.shape[1])\n", - " if src.shape[0] < self.shape[0] or src.shape[1] < self.shape[1]:\n", - " raise ValueError('Patch shape larger than (downscaled) source shape')\n", - " \n", - " for i in range(self.shape[2]):\n", - " if self.random_crop:\n", - " src = src_list[i]\n", - " tgt = tgt_list[i]\n", - " src_crop = src[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n", - " tgt_crop = tgt[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n", - " else:\n", - " src_crop = src_list[i]\n", - " tgt_crop = tgt_list[i]\n", - "\n", - " source_batch[batch,:,:,i,0] = src_crop\n", - " target_batch[batch,:,:,i,0] = tgt_crop\n", - "\n", - " if self.augment:\n", - " # On-the-fly data augmentation\n", - " source_batch, target_batch = self.augment_volume(source_batch, target_batch)\n", - "\n", - " # Data augmentation by reversing stack\n", - " if 
np.random.random() > 0.5:\n", - " source_batch, target_batch = source_batch[::-1], target_batch[::-1]\n", - " \n", - " # Data augmentation by elastic deformation\n", - " if np.random.random() > 0.5 and self.deform_augment:\n", - " source_batch, target_batch = self.deform_volume(source_batch, target_batch)\n", - " \n", - " if not self.binary_target:\n", - " target_batch = self._min_max_scaling(target_batch)\n", - " \n", - " return self._min_max_scaling(source_batch), target_batch\n", - " \n", - " else:\n", - " return source_batch, target_batch\n", - "\n", - " def on_epoch_end(self):\n", - " # Validation split performed here\n", - " self.batch_list = []\n", - " # Create batch_list of all combinations of tifffile and stack position\n", - " if self.dir_flag:\n", - " for i in range(len(self.source_dir_list)):\n", - " num_of_pages = tifffile.imread(self.source_dir_list[i]).shape[0]\n", - " if self.is_val:\n", - " start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n", - " for j in range(start_page, num_of_pages-self.shape[2]):\n", - " self.batch_list.append([i, j])\n", - " else:\n", - " last_page = math.floor((1-self.val_split)*num_of_pages)\n", - " for j in range(last_page-self.shape[2]):\n", - " self.batch_list.append([i, j])\n", - " else:\n", - " num_of_pages = self.source.shape[0]\n", - " if self.is_val:\n", - " start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n", - " for j in range(start_page, num_of_pages-self.shape[2]):\n", - " self.batch_list.append([0, j])\n", - "\n", - " else:\n", - " last_page = math.floor((1-self.val_split)*num_of_pages)\n", - " for j in range(last_page-self.shape[2]):\n", - " self.batch_list.append([0, j])\n", - " \n", - " if self.is_val and (len(self.batch_list) <= 0):\n", - " raise ValueError('validation_split too small! 
Increase val_split or decrease z-depth')\n", - " random.shuffle(self.batch_list)\n", - " \n", - " def _min_max_scaling(self, data):\n", - " n = data - np.min(data)\n", - " d = np.max(data) - np.min(data) \n", - " \n", - " return n/d\n", - " \n", - " def class_weights(self):\n", - " ones = 0\n", - " pixels = 0\n", - "\n", - " if self.dir_flag:\n", - " for i in range(len(self.target_dir_list)):\n", - " tgt = tifffile.imread(self.target_dir_list[i]).astype(np.bool)\n", - " ones += np.sum(tgt)\n", - " pixels += tgt.shape[0]*tgt.shape[1]*tgt.shape[2]\n", - " else:\n", - " ones = np.sum(self.target)\n", - " pixels = self.target.shape[0]*self.target.shape[1]*self.target.shape[2]\n", - " p_ones = ones/pixels\n", - " p_zeros = 1-p_ones\n", - "\n", - " # Return swapped probability to increase weight of unlikely class\n", - " return p_ones, p_zeros\n", - "\n", - " def deform_volume(self, src_vol, tgt_vol):\n", - " [src_dfrm, tgt_dfrm] = elasticdeform.deform_random_grid([src_vol, tgt_vol],\n", - " axis=(1, 2, 3),\n", - " sigma=self.deform_sigma,\n", - " points=self.deform_points,\n", - " order=self.deform_order)\n", - " if self.binary_target:\n", - " tgt_dfrm = tgt_dfrm > 0.1\n", - " \n", - " return self._min_max_scaling(src_dfrm), tgt_dfrm \n", - "\n", - " def augment_volume(self, src_vol, tgt_vol):\n", - " src_vol_aug = np.empty(src_vol.shape)\n", - " tgt_vol_aug = np.empty(tgt_vol.shape)\n", - "\n", - " for i in range(src_vol.shape[3]):\n", - " src_vol_aug[:,:,:,i,0], tgt_vol_aug[:,:,:,i,0] = self.seq(images=src_vol[:,:,:,i,0].astype('float16'), \n", - " segmentation_maps=tgt_vol[:,:,:,i,0].astype(bool))\n", - " return self._min_max_scaling(src_vol_aug), tgt_vol_aug\n", - "\n", - " def sample_augmentation(self, idx):\n", - " src, tgt = self.__getitem__(idx)\n", - "\n", - " src_aug, tgt_aug = self.augment_volume(src, tgt)\n", - " \n", - " if self.deform_augment:\n", - " src_aug, tgt_aug = self.deform_volume(src_aug, tgt_aug)\n", - "\n", - " return src_aug, tgt_aug \n", - "\n", - "# Define custom loss and dice coefficient\n", - "def dice_coefficient(y_true, y_pred):\n", - " eps = 1e-6\n", - " y_true_f = K.flatten(y_true)\n", - " y_pred_f = K.flatten(y_pred)\n", - " intersection = K.sum(y_true_f*y_pred_f)\n", - "\n", - " return (2.*intersection)/(K.sum(y_true_f*y_true_f)+K.sum(y_pred_f*y_pred_f)+eps)\n", - "\n", - "def weighted_binary_crossentropy(zero_weight, one_weight):\n", - " def _weighted_binary_crossentropy(y_true, y_pred):\n", - " binary_crossentropy = K.binary_crossentropy(y_true, y_pred)\n", - "\n", - " weight_vector = y_true*one_weight+(1.-y_true)*zero_weight\n", - " weighted_binary_crossentropy = weight_vector*binary_crossentropy\n", - "\n", - " return K.mean(weighted_binary_crossentropy)\n", - "\n", - " return _weighted_binary_crossentropy\n", - "\n", - "# Custom callback showing sample prediction\n", - "class SampleImageCallback(Callback):\n", - "\n", - " def __init__(self, model, sample_data, model_path, save=False):\n", - " self.model = model\n", - " self.sample_data = sample_data\n", - " self.model_path = model_path\n", - " self.save = save\n", - "\n", - " def on_epoch_end(self, epoch, logs={}):\n", - " sample_predict = self.model.predict_on_batch(self.sample_data)\n", - "\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(self.sample_data[0,:,:,0,0], interpolation='nearest', cmap='gray')\n", - " plt.title('Sample source')\n", - " plt.axis('off');\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(sample_predict[0,:,:,0,0], interpolation='nearest', 
cmap='magma')\n", - " plt.title('Predicted target')\n", - " plt.axis('off');\n", - "\n", - " plt.show()\n", - "\n", - " if self.save:\n", - " plt.savefig(self.model_path + '/epoch_' + str(epoch+1) + '.png')\n", - "\n", - "\n", - "# Define Unet3D class\n", - "class Unet3D:\n", - "\n", - " def __init__(self,\n", - " shape=(256,256,16,1)):\n", - " if isinstance(shape, str):\n", - " shape = eval(shape)\n", - "\n", - " self.shape = shape\n", - " \n", - " input_tensor = Input(self.shape, name='input')\n", - "\n", - " self.model = self.unet_3D(input_tensor)\n", - "\n", - " def down_block_3D(self, input_tensor, filters):\n", - " x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(input_tensor)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " return x\n", - "\n", - " def up_block_3D(self, input_tensor, concat_layer, filters):\n", - " x = Conv3DTranspose(filters, kernel_size=(2,2,2), strides=(2,2,2))(input_tensor)\n", - "\n", - " x = Concatenate()([x, concat_layer])\n", - "\n", - " x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(x)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n", - " x = BatchNormalization()(x)\n", - " x = ReLU()(x)\n", - "\n", - " return x\n", - "\n", - " def unet_3D(self, input_tensor, filters=32):\n", - " d1 = self.down_block_3D(input_tensor, filters=filters)\n", - " p1 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d1)\n", - " d2 = self.down_block_3D(p1, filters=filters*2)\n", - " p2 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d2)\n", - " d3 = self.down_block_3D(p2, filters=filters*4)\n", - " p3 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d3)\n", - "\n", - " d4 = self.down_block_3D(p3, filters=filters*8)\n", - "\n", - " u1 = self.up_block_3D(d4, d3, filters=filters*4)\n", - " u2 = self.up_block_3D(u1, d2, filters=filters*2)\n", - " u3 = self.up_block_3D(u2, d1, filters=filters)\n", - "\n", - " output_tensor = Conv3D(filters=1, kernel_size=(1,1,1), activation='sigmoid')(u3)\n", - "\n", - " return Model(inputs=[input_tensor], outputs=[output_tensor])\n", - "\n", - " def summary(self):\n", - " return self.model.summary()\n", - "\n", - " # Pass generators instead\n", - " def train(self, \n", - " epochs, \n", - " batch_size, \n", - " train_generator,\n", - " val_generator, \n", - " model_path, \n", - " model_name,\n", - " optimizer='adam',\n", - " loss='weighted_binary_crossentropy',\n", - " metrics='dice',\n", - " ckpt_period=1, \n", - " save_best_ckpt_only=False, \n", - " ckpt_path=None):\n", - "\n", - " class_weight_zero, class_weight_one = train_generator.class_weights()\n", - " \n", - " if loss == 'weighted_binary_crossentropy':\n", - " loss = weighted_binary_crossentropy(class_weight_zero, class_weight_one)\n", - " \n", - " if metrics == 'dice':\n", - " metrics = dice_coefficient\n", - "\n", - " self.model.compile(optimizer=optimizer,\n", - " loss=loss,\n", - " metrics=[metrics])\n", - "\n", - " if ckpt_path is not None:\n", - " self.model.load_weights(ckpt_path)\n", - "\n", - " full_model_path = os.path.join(model_path, model_name)\n", - "\n", - " if not os.path.exists(full_model_path):\n", - " os.makedirs(full_model_path)\n", - " \n", - " log_dir = full_model_path + '/Quality 
Control'\n", - "\n", - " if not os.path.exists(log_dir):\n", - " os.makedirs(log_dir)\n", - " \n", - " ckpt_dir = full_model_path + '/ckpt'\n", - "\n", - " if not os.path.exists(ckpt_dir):\n", - " os.makedirs(ckpt_dir)\n", - "\n", - " csv_out_name = log_dir + '/training_evaluation.csv'\n", - " if ckpt_path is None:\n", - " csv_logger = CSVLogger(csv_out_name)\n", - " else:\n", - " csv_logger = CSVLogger(csv_out_name, append=True)\n", - "\n", - " if save_best_ckpt_only:\n", - " ckpt_name = ckpt_dir + '/' + model_name + '.hdf5'\n", - " else:\n", - " ckpt_name = ckpt_dir + '/' + model_name + '_epoch_{epoch:02d}_val_loss_{val_loss:.4f}.hdf5'\n", - " \n", - " model_ckpt = ModelCheckpoint(ckpt_name,\n", - " verbose=1,\n", - " period=ckpt_period,\n", - " save_best_only=save_best_ckpt_only,\n", - " save_weights_only=True)\n", - "\n", - " sample_batch, __ = val_generator.__getitem__(random.randint(0, len(val_generator)))\n", - " sample_img = SampleImageCallback(self.model, \n", - " sample_batch, \n", - " model_path)\n", - "\n", - " self.model.fit_generator(generator=train_generator,\n", - " validation_data=val_generator,\n", - " validation_steps=math.floor(len(val_generator)/batch_size),\n", - " epochs=epochs,\n", - " callbacks=[csv_logger,\n", - " model_ckpt,\n", - " sample_img])\n", - "\n", - " last_ckpt_name = ckpt_dir + '/' + model_name + '_last.hdf5'\n", - " self.model.save_weights(last_ckpt_name)\n", - "\n", - " def _min_max_scaling(self, data):\n", - " n = data - np.min(data)\n", - " d = np.max(data) - np.min(data) \n", - " \n", - " return n/d\n", - "\n", - " def predict(self, \n", - " input, \n", - " ckpt_path, \n", - " z_range=None, \n", - " downscaling=None, \n", - " true_patch_size=None):\n", - "\n", - " self.model.load_weights(ckpt_path)\n", - "\n", - " if isinstance(downscaling, str):\n", - " downscaling = eval(downscaling)\n", - "\n", - " if math.isnan(downscaling):\n", - " downscaling = None\n", - "\n", - " if isinstance(true_patch_size, str):\n", - " true_patch_size = eval(true_patch_size)\n", - " \n", - " if not isinstance(true_patch_size, tuple): \n", - " if math.isnan(true_patch_size):\n", - " true_patch_size = None\n", - "\n", - " if isinstance(input, str):\n", - " src_volume = tifffile.imread(input)\n", - " elif isinstance(input, np.ndarray):\n", - " src_volume = input\n", - " else:\n", - " raise TypeError('Input is not path or numpy array!')\n", - " \n", - " in_size = src_volume.shape\n", - "\n", - " if downscaling or true_patch_size is not None:\n", - " x_scaling = 0\n", - " y_scaling = 0\n", - "\n", - " if true_patch_size is not None:\n", - " x_scaling += true_patch_size[0]/self.shape[0]\n", - " y_scaling += true_patch_size[1]/self.shape[1]\n", - " if downscaling is not None:\n", - " x_scaling += downscaling\n", - " y_scaling += downscaling\n", - "\n", - " src_list = []\n", - " for i in range(src_volume.shape[0]):\n", - " src_list.append(transform.downscale_local_mean(src_volume[i], (int(x_scaling), int(y_scaling))))\n", - " src_volume = np.array(src_list) \n", - "\n", - " if z_range is not None:\n", - " src_volume = src_volume[z_range[0]:z_range[1]]\n", - "\n", - " src_volume = self._min_max_scaling(src_volume) \n", - "\n", - " src_array = np.zeros((1,\n", - " math.ceil(src_volume.shape[1]/self.shape[0])*self.shape[0], \n", - " math.ceil(src_volume.shape[2]/self.shape[1])*self.shape[1],\n", - " math.ceil(src_volume.shape[0]/self.shape[2])*self.shape[2], \n", - " self.shape[3]))\n", - "\n", - " for i in range(src_volume.shape[0]):\n", - " 
src_array[0,:src_volume.shape[1],:src_volume.shape[2],i,0] = src_volume[i]\n", - "\n", - " pred_array = np.empty(src_array.shape)\n", - "\n", - " for i in range(math.ceil(src_volume.shape[1]/self.shape[0])):\n", - " for j in range(math.ceil(src_volume.shape[2]/self.shape[1])):\n", - " for k in range(math.ceil(src_volume.shape[0]/self.shape[2])):\n", - " pred_temp = self.model.predict(src_array[:,\n", - " i*self.shape[0]:i*self.shape[0]+self.shape[0],\n", - " j*self.shape[1]:j*self.shape[1]+self.shape[1],\n", - " k*self.shape[2]:k*self.shape[2]+self.shape[2]])\n", - " pred_array[:,\n", - " i*self.shape[0]:i*self.shape[0]+self.shape[0],\n", - " j*self.shape[1]:j*self.shape[1]+self.shape[1],\n", - " k*self.shape[2]:k*self.shape[2]+self.shape[2]] = pred_temp\n", - " \n", - " pred_volume = np.rollaxis(np.squeeze(pred_array), -1)[:src_volume.shape[0],:src_volume.shape[1],:src_volume.shape[2]] \n", - "\n", - " if downscaling is not None:\n", - " pred_list = []\n", - " for i in range(pred_volume.shape[0]):\n", - " pred_list.append(transform.resize(pred_volume[i], (in_size[1], in_size[2]), preserve_range=True))\n", - " pred_volume = np.array(pred_list)\n", - "\n", - " return pred_volume\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'U-Net 3D'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and methods:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','Keras']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " if os.path.isdir(training_source):\n", - " shape = io.imread(training_source+'/'+os.listdir(training_source)[0]).shape\n", - " elif os.path.isfile(training_source):\n", - " shape = io.imread(training_source).shape\n", - " else:\n", - " print('Cannot read training data.')\n", - "\n", - " 
dataset_size = len(train_generator)\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " if pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch_size: '+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented by'\n", - " if add_gaussian_blur == True:\n", - " aug_text = aug_text+'\\n- gaussian blur'\n", - " if add_linear_contrast == True:\n", - " aug_text = aug_text+'\\n- linear contrast'\n", - " if add_additive_gaussian_noise == True:\n", - " aug_text = aug_text+'\\n- additive gaussian noise'\n", - " if augmenters != '':\n", - " aug_text = aug_text+'\\n- imgaug augmentations: '+augmenters\n", - " if add_elastic_deform == True:\n", - " aug_text = aug_text+'\\n- elastic deformation'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if use_default_advanced_parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<table width=50% style=\"margin-left:0px;\">\n",
 - "  <tr><th align=\"left\">Parameter</th><th align=\"left\">Value</th></tr>\n",
 - "  <tr><td>number_of_epochs</td><td>{0}</td></tr>\n",
 - "  <tr><td>batch_size</td><td>{1}</td></tr>\n",
 - "  <tr><td>patch_size</td><td>{2}</td></tr>\n",
 - "  <tr><td>image_pre_processing</td><td>{3}</td></tr>\n",
 - "  <tr><td>validation_split_in_percent</td><td>{4}</td></tr>\n",
 - "  <tr><td>downscaling_in_xy</td><td>{5}</td></tr>\n",
 - "  <tr><td>binary_target</td><td>{6}</td></tr>\n",
 - "  <tr><td>loss_function</td><td>{7}</td></tr>\n",
 - "  <tr><td>metrics</td><td>{8}</td></tr>\n",
 - "  <tr><td>optimizer</td><td>{9}</td></tr>\n",
 - "  <tr><td>checkpointing_period</td><td>{10}</td></tr>\n",
 - "  <tr><td>save_best_only</td><td>{11}</td></tr>\n",
 - "  <tr><td>resume_training</td><td>{12}</td></tr>\n",
 - "  </table>
\n", - " \"\"\".format(number_of_epochs,batch_size,str(patch_size[0])+'x'+str(patch_size[1])+'x'+str(patch_size[2]),image_pre_processing, validation_split_in_percent, downscaling_in_xy, str(binary_target), loss_function, metrics, optimizer, checkpointing_period, str(save_best_only), str(resume_training))\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_Unet3D.png').shape\n", - " pdf.image('/content/TrainingDataExample_Unet3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Unet 3D: Çiçek, Özgün, et al. \"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " # if Use_Data_augmentation:\n", - " # ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " # pdf.multi_cell(190, 5, txt = ref_4, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n", - "\n", - " print('------------------------------')\n", - " print('PDF report exported in '+model_path+'/'+model_name+'/')\n", - "\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'U-Net 3D'\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+qc_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n", - " pdf.ln(1)\n", - " if os.path.exists(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png'):\n", - " exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png').shape\n", - " pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " else:\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size=10)\n", - " pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png').shape\n", - " pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'IoU threshold optimisation', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.ln(1)\n", - " pdf.cell(120, 5, txt='Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh), align='L', ln=1)\n", - " pdf.ln(2)\n", - " exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png').shape\n", - " pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png', x=16, y=None, w = round(exp_size[1]/6), h = round(exp_size[0]/6))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. 
\"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" bioRxiv (2020).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- Unet 3D: Çiçek, Özgün, et al. \"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/'+qc_model_name+'_QC_report.pdf')\n", - "\n", - " print('------------------------------')\n", - " print('QC PDF report exported in '+os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/')\n", - "\n", - "\n", - "# -------------- Other definitions -----------\n", - "W = '\\033[0m' # white (normal)\n", - "R = '\\033[31m' # red\n", - "prediction_prefix = 'Predicted_'\n", - "\n", - "\n", - "print('-------------------')\n", - "print('U-Net 3D and dependencies installed.')\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - " NORMAL = '\\033[0m' # white (normal)\n", - " \n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "if Notebook_version == list(Latest_notebook_version.columns):\n", - " print(\"This notebook is up-to-date.\")\n", - "\n", - "if not Notebook_version == list(Latest_notebook_version.columns):\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "\n", - "# Exporting requirements.txt for local run\n", - "!pip freeze > requirements.txt\n", - "after = [str(m) for m in sys.modules]\n", - "# Get minimum requirements file\n", - "\n", - "#Add the following lines before all imports: \n", - "# import sys\n", - "# before = [str(m) for m in sys.modules]\n", - "\n", - "#Add the following line after the imports:\n", - "# after = [str(m) for m in sys.modules]\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "df = pd.read_csv('requirements.txt', delimiter = \"\\n\")\n", - "mod_list = [m.split('.')[0] for m in after if not m in before]\n", - "req_list_temp = df.values.tolist()\n", - "req_list = [x[0] for x in req_list_temp]\n", - "\n", - "# Replace with package name \n", - "mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - "mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - "filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - "file=open('3D_UNet_requirements_simple.txt','w')\n", - "for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - "file.close()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HLYcZR9gMv42" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AuESFimvMv43" - }, - "source": [ - "## **3.1. Choosing parameters**\n", - "\n", - "---\n", - "\n", - "### **Paths to training data and model**\n", - "\n", - "* **`training_source`** and **`training_target`** specify the paths to the training data. They can either be a single multipage TIFF file each or directories containing various multipage TIFF files in which case target and source files must be named identically within the respective directories. See Section 0 for a detailed description of the necessary directory structure.\n", - "\n", - "* **`model_name`** will be used when naming checkpoints. Adhere to a `lower_case_with_underscores` naming convention and beware of using the name of an existing model within the same folder, as it will be overwritten.\n", - "\n", - "* **`model_path`** specifies the directory where the model checkpoints and quality control logs will be saved.\n", - "\n", - "\n", - "**Note:** You can copy paths from the 'Files' tab by right-clicking any folder or file and selecting 'Copy path'. \n", - "\n", - "### **Training parameters**\n", - "\n", - "* **`number_of_epochs`** is the number of times the entire training data will be seen by the model. *Default: >100*\n", - "\n", - "* **`batch_size`** is the number of training patches of size `patch_size` that will be bundled together at each training step. *Default: 1*\n", - "\n", - "* **`patch_size`** specifies the size of the three-dimensional training patches in (x, y, z) that will be fed to the model. In order to avoid errors, preferably use a square aspect ratio or stick to the advanced parameters. *Default: <(512, 512, 16)*\n", - "\n", - "* **`validation_split_in_percent`** is the relative amount of training data that will be set aside for validation. 
*Default: 20* \n", - "\n", - "* **`downscaling_in_xy`** downscales the training images by the specified amount in x and y. This is useful to enforce an isotropic pixel size if the z resolution is lower than the xy resolution in the training volume, or to capture a larger field-of-view while decreasing the memory requirements. *Default: 1*\n", - "\n", - "* **`image_pre_processing`** selects whether the training images are randomly cropped during training or resized to `patch_size`. Choose `randomly crop to patch_size` to shrink the field-of-view of the training images to the `patch_size`. *Default: randomly crop to patch_size* \n", - "\n", - "* **`binary_target`** forces the target image to be binary. Choose this if your model is trained to perform binary segmentation tasks. *Default: True* \n", - "\n", - "* **`loss_function`** defines the loss. Read more [here](https://keras.io/api/losses/). *Default: weighted_binary_crossentropy* \n", - "\n", - "* **`metrics`** defines the metric. Read more [here](https://keras.io/api/metrics/). *Default: dice* \n", - "\n", - "* **`optimizer`** defines the optimizer. Read more [here](https://keras.io/api/optimizers/). *Default: adam* \n", - "\n", - "**Note:** If a *ResourceExhaustedError* is raised in Section 4.1. during training, decrease `batch_size` and `patch_size`. Decrease `batch_size` first and, if the error persists at `batch_size = 1`, reduce the `patch_size`. \n", - "\n", - "**Note:** The number of steps per epoch is calculated as `floor(augment_factor * (1 - validation_split) * num_of_slices / batch_size)` if `image_pre_processing` is `resize to patch_size`, where `augment_factor` is four if `apply_data_augmentation` is `True` and one otherwise. The `num_of_slices` is the overall number of slices (z-depth) in the training set across all provided image volumes. If `image_pre_processing` is `randomly crop to patch_size`, the number of steps per epoch is calculated as `floor(augment_factor * volume / (crop_volume * batch_size * downscaling_in_xy))`, where `volume` is the overall volume of the training data in pixels accounting for the validation split and `crop_volume` is the volume in pixels of the specified `patch_size`."
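To make the steps-per-epoch formulas above concrete, here is a minimal sketch (not part of the original notebook) that reproduces the calculation performed by the generator's `__len__` method; the slice count, slice shape, patch size and batch size below are hypothetical example values.

```
import math

# Hypothetical example values, not the notebook defaults
num_of_slices = 200            # total z-slices across all training volumes
slice_shape = (1024, 1024)     # x, y size of each slice
patch_size = (256, 256, 8)     # x, y, z patch fed to the network
batch_size = 1
validation_split = 0.2         # validation_split_in_percent / 100
downscaling_in_xy = 1
augment_factor = 4             # 1 if apply_data_augmentation is False

# image_pre_processing = 'randomly crop to patch_size'
crop_volume = patch_size[0] * patch_size[1] * patch_size[2]
volume = slice_shape[0] * slice_shape[1] * (1 - validation_split) * num_of_slices
steps_random_crop = math.floor(augment_factor * volume /
                               (crop_volume * batch_size * downscaling_in_xy))

# image_pre_processing = 'resize to patch_size'
steps_resize = math.floor(augment_factor * (1 - validation_split) * num_of_slices / batch_size)

print(steps_random_crop, steps_resize)   # 1280 640
```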
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "ewpNJ_I0Mv47" - }, - "outputs": [], - "source": [ - "#@markdown ###Path to training data:\n", - "training_source = \"\" #@param {type:\"string\"}\n", - "training_target = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ---\n", - "\n", - "#@markdown ###Model name and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "full_model_path = os.path.join(model_path, model_name)\n", - "\n", - "#@markdown ---\n", - "\n", - "#@markdown ###Training parameters\n", - "number_of_epochs = 100#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Default advanced parameters\n", - "use_default_advanced_parameters = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown If not, please change:\n", - "\n", - "batch_size = 1#@param {type:\"number\"}\n", - "patch_size = (256,256,4) #@param {type:\"number\"} # in pixels\n", - "training_shape = patch_size + (1,)\n", - "image_pre_processing = 'randomly crop to patch_size' #@param [\"randomly crop to patch_size\", \"resize to patch_size\"]\n", - "\n", - "validation_split_in_percent = 20 #@param{type:\"number\"}\n", - "downscaling_in_xy = 2#@param {type:\"number\"} # in pixels\n", - "\n", - "binary_target = True #@param {type:\"boolean\"}\n", - "\n", - "loss_function = 'weighted_binary_crossentropy' #@param [\"weighted_binary_crossentropy\", \"binary_crossentropy\", \"categorical_crossentropy\", \"sparse_categorical_crossentropy\", \"mean_squared_error\", \"mean_absolute_error\"]\n", - "\n", - "metrics = 'dice' #@param [\"dice\", \"accuracy\"]\n", - "\n", - "optimizer = 'adam' #@param [\"adam\", \"sgd\", \"rmsprop\"]\n", - "\n", - "\n", - "if image_pre_processing == \"randomly crop to patch_size\":\n", - " random_crop = True\n", - "else:\n", - " random_crop = False\n", - "\n", - "if use_default_advanced_parameters: \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 1\n", - " training_shape = (256,256,8,1)\n", - " validation_split_in_percent = 20\n", - " downscaling_in_xy = 1\n", - " random_crop = True\n", - " binary_target = True\n", - " loss_function = 'weighted_binary_crossentropy'\n", - " metrics = 'dice'\n", - " optimizer = 'adam'\n", - "\n", - "#@markdown ###Checkpointing parameters\n", - "checkpointing_period = 1 #@param {type:\"number\"}\n", - "\n", - "#@markdown If chosen only the best checkpoint is saved, otherwise a checkpoint is saved every checkpoint_period epochs:\n", - "save_best_only = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###Resume training\n", - "#@markdown Choose if training was interrupted:\n", - "resume_training = False #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###Transfer learning\n", - "#@markdown For transfer learning, do not select resume_training and specify a checkpoint_path below:\n", - "checkpoint_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if resume_training and checkpoint_path != \"\":\n", - " print('If resume_training is True while checkpoint_path is specified, resume_training will be set to False!')\n", - " resume_training = False\n", - " \n", - "\n", - "# Retrieve last checkpoint\n", - "if resume_training:\n", - " try:\n", - " ckpt_dir_list = glob(full_model_path + '/ckpt/*')\n", - " ckpt_dir_list.sort()\n", - " last_ckpt_path = ckpt_dir_list[-1]\n", - " print('Training will resume from checkpoint:', os.path.basename(last_ckpt_path))\n", - " except IndexError:\n", 
- " last_ckpt_path=None\n", - " print('CheckpointError: No previous checkpoints were found, training from scratch.')\n", - "elif not resume_training and checkpoint_path != \"\":\n", - " last_ckpt_path = checkpoint_path\n", - " assert os.path.isfile(last_ckpt_path), 'checkpoint_path does not exist!'\n", - "else:\n", - " last_ckpt_path=None\n", - "\n", - "# Instantiate Unet3D \n", - "model = Unet3D(shape=training_shape)\n", - "\n", - "#here we check that no model with the same name already exist\n", - "if not resume_training and os.path.exists(full_model_path): \n", - " print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n", - " # print('!! WARNING: Folder already exists and will be overwritten !!') \n", - " # shutil.rmtree(full_model_path)\n", - "\n", - "# if not os.path.exists(full_model_path):\n", - "# os.makedirs(full_model_path)\n", - "\n", - "# Show sample image\n", - "if os.path.isdir(training_source):\n", - " training_source_sample = sorted(glob(os.path.join(training_source, '*')))[0]\n", - " training_target_sample = sorted(glob(os.path.join(training_target, '*')))[0]\n", - "else:\n", - " training_source_sample = training_source\n", - " training_target_sample = training_target\n", - "\n", - "src_sample = tifffile.imread(training_source_sample)\n", - "src_sample = model._min_max_scaling(src_sample)\n", - "if binary_target:\n", - " tgt_sample = tifffile.imread(training_target_sample).astype(np.bool)\n", - "else:\n", - " tgt_sample = tifffile.imread(training_target_sample)\n", - "\n", - "src_down = transform.downscale_local_mean(src_sample[0], (downscaling_in_xy, downscaling_in_xy))\n", - "tgt_down = transform.downscale_local_mean(tgt_sample[0], (downscaling_in_xy, downscaling_in_xy)) \n", - "\n", - "if random_crop:\n", - " true_patch_size = None\n", - "\n", - " if src_down.shape[0] == training_shape[0]:\n", - " x_rand = 0\n", - " if src_down.shape[1] == training_shape[1]:\n", - " y_rand = 0\n", - " if src_down.shape[0] > training_shape[0]:\n", - " x_rand = np.random.randint(src_down.shape[0] - training_shape[0])\n", - " if src_down.shape[1] > training_shape[1]:\n", - " y_rand = np.random.randint(src_down.shape[1] - training_shape[1])\n", - " if src_down.shape[0] < training_shape[0] or src_down.shape[1] < training_shape[1]:\n", - " raise ValueError('Patch shape larger than (downscaled) source shape')\n", - "else:\n", - " true_patch_size = src_down.shape\n", - "\n", - "def scroll_in_z(z):\n", - " src_down = transform.downscale_local_mean(src_sample[z-1], (downscaling_in_xy,downscaling_in_xy))\n", - " tgt_down = transform.downscale_local_mean(tgt_sample[z-1], (downscaling_in_xy,downscaling_in_xy)) \n", - " if random_crop:\n", - " src_slice = src_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n", - " tgt_slice = tgt_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n", - " else:\n", - " \n", - " src_slice = transform.resize(src_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n", - " tgt_slice = transform.resize(tgt_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n", - "\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(src_slice, cmap='gray')\n", - " plt.title('Training source (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(tgt_slice, cmap='magma')\n", - " plt.title('Training target (z = ' + str(z) + ')', fontsize=15)\n", - " 
plt.axis('off')\n", - " plt.savefig('/content/TrainingDataExample_Unet3D.png',bbox_inches='tight',pad_inches=0)\n", - " #plt.close()\n", - "\n", - "print('This is what the training images will look like with the chosen settings')\n", - "interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_sample.shape[0], step=1, value=0));\n", - "\n", - "#Create a copy of an example slice and close the display.\n", - "scroll_in_z(z=int(src_sample.shape[0]/2))\n", - "plt.close()\n", - "\n", - "# Save model parameters\n", - "params = {'training_source': training_source,\n", - " 'training_target': training_target,\n", - " 'model_name': model_name,\n", - " 'model_path': model_path,\n", - " 'number_of_epochs': number_of_epochs,\n", - " 'batch_size': batch_size,\n", - " 'training_shape': training_shape,\n", - " 'downscaling': downscaling_in_xy,\n", - " 'true_patch_size': true_patch_size,\n", - " 'val_split': validation_split_in_percent/100,\n", - " 'random_crop': random_crop}\n", - "\n", - "params_df = pd.DataFrame.from_dict(params, orient='index')\n", - "\n", - "# apply_data_augmentation = False\n", - "# pdf_export(augmentation = apply_data_augmentation, pretrained_model = resume_training)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_jCy7xOx2g3" - }, - "source": [ - "## **3.2. Data augmentation**\n", - " \n", - "---\n", - " Augmenting the training data increases the robustness of the model by simulating possible variations within the training data, which helps prevent the model from overfitting on small datasets. We therefore strongly recommend augmenting the data and making sure that the applied augmentations are reasonable.\n", - "\n", - "* **Gaussian blur** blurs images using Gaussian kernels with a sigma of `gaussian_sigma`. This augmentation step is applied with a probability of `gaussian_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/blur.html#gaussianblur).\n", - "\n", - "* **Linear contrast** modifies the contrast of images according to `127 + alpha *(pixel_value-127)`, where `pixel_value` is the original pixel value and `alpha` is sampled uniformly from the interval `[contrast_min, contrast_max]`. This augmentation step is applied with a probability of `contrast_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/contrast.html#linearcontrast).\n", - "\n", - "* **Additive Gaussian noise** adds Gaussian noise sampled once per pixel from a normal distribution `N(0, s)`, where `s` is sampled from `[scale_min, scale_max]`. This augmentation step is applied with a probability of `noise_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/arithmetic.html#additivegaussiannoise).\n", - "\n", - "* **Add custom augmenters** allows you to create a custom augmentation pipeline using the [augmenters available in the imgaug library](https://imgaug.readthedocs.io/en/latest/source/overview_of_augmenters.html).\n", - "With the default values shown in the cell below, the custom augmentation pipeline is equivalent to: \n", - "```\n", - "seq = iaa.Sequential([\n", - "    iaa.Sometimes(0.3, iaa.GammaContrast((0.5, 2.0))),\n", - "    iaa.Sometimes(0.4, iaa.AverageBlur((0.5, 2.0))),\n", - "    iaa.Sometimes(0.5, iaa.LinearContrast((0.4, 1.6)))\n", - "], random_order=True)\n", - "```\n", - " Note that there is no limit on the number of augmenters that can be chained together and that individual augmenter and parameter entries must be separated by `;`. Custom augmenters do not overwrite the preset augmentation steps (*Gaussian blur*, *Linear contrast* or *Additive Gaussian noise*). 
Also, the augmenters, augmenter parameters and augmenter frequencies must be entered such that each position within the string corresponds to the same augmentation step.\n", - "\n", - "* **`apply_data_augmentation`** ensures that data augmentation is randomly applied to the training data at each training step. This includes inverting the order of the slices within a training patch, as well as applying any augmenters that are added. *Default: True*\n", - "\n", - "* **`add_elastic_deform`** ensures that elastic grid-based deformations are applied as described in the original 3D U-Net paper. *Default: True*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "DMqWq5-AxnFU" - }, - "outputs": [], - "source": [ - "#@markdown ##**Augmentation options**\n", - "\n", - "#@markdown ###Data augmentation\n", - "\n", - "apply_data_augmentation = False #@param {type:\"boolean\"}\n", - "\n", - "# List of augmentations\n", - "augmentations = []\n", - "\n", - "#@markdown ###Gaussian blur\n", - "add_gaussian_blur = True #@param {type:\"boolean\"}\n", - "gaussian_sigma = 0.7#@param {type:\"number\"}\n", - "gaussian_frequency = 0.5 #@param {type:\"number\"}\n", - "\n", - "if add_gaussian_blur:\n", - " augmentations.append(iaa.Sometimes(gaussian_frequency, iaa.GaussianBlur(sigma=(0, gaussian_sigma))))\n", - "\n", - "#@markdown ###Linear contrast\n", - "add_linear_contrast = True #@param {type:\"boolean\"}\n", - "contrast_min = 0.4 #@param {type:\"number\"}\n", - "contrast_max = 1.6#@param {type:\"number\"}\n", - "contrast_frequency = 0.5 #@param {type:\"number\"}\n", - "\n", - "if add_linear_contrast:\n", - " augmentations.append(iaa.Sometimes(contrast_frequency, iaa.LinearContrast((contrast_min, contrast_max))))\n", - "\n", - "#@markdown ###Additive Gaussian noise\n", - "add_additive_gaussian_noise = False #@param {type:\"boolean\"}\n", - "scale_min = 0 #@param {type:\"number\"}\n", - "scale_max = 0.05 #@param {type:\"number\"}\n", - "noise_frequency = 0.5 #@param {type:\"number\"}\n", - "\n", - "if add_additive_gaussian_noise:\n", - " augmentations.append(iaa.Sometimes(noise_frequency, iaa.AdditiveGaussianNoise(scale=(scale_min, scale_max))))\n", - "\n", - "#@markdown ###Add custom augmenters\n", - "\n", - "augmenters = \"GammaContrast; AverageBlur; LinearContrast\" #@param {type:\"string\"}\n", - "\n", - "augmenter_params = \"(0.5, 2.0); (0.5, 2.0); (0.4, 1.6)\" #@param {type:\"string\"}\n", - "\n", - "augmenter_frequency = \"0.3; 0.4; 0.5\" #@param {type:\"string\"}\n", - "\n", - "aug_lst = augmenters.split(';')\n", - "aug_params_lst = augmenter_params.split(';')\n", - "aug_freq_lst = augmenter_frequency.split(';')\n", - "\n", - "assert len(aug_lst) == len(aug_params_lst) and len(aug_lst) == len(aug_freq_lst), 'The number of arguments in augmenters, augmenter_params and augmenter_frequency are not the same!'\n", - "\n", - "for __, (aug, param, freq) in enumerate(zip(aug_lst, aug_params_lst, aug_freq_lst)):\n", - " aug, param, freq = aug.strip(), param.strip(), freq.strip() \n", - " aug_func = iaa.Sometimes(eval(freq), getattr(iaa, aug)(eval(param)))\n", - " augmentations.append(aug_func)\n", - "\n", - "#@markdown ###Elastic deformations\n", - "add_elastic_deform = True #@param {type:\"boolean\"}\n", - "sigma = 2#@param {type:\"number\"}\n", - "points = 2#@param {type:\"number\"}\n", - "order = 2#@param {type:\"number\"}\n", - "\n", - "if add_elastic_deform:\n", - " deform_params = (sigma, points, order)\n", - "else:\n", - " deform_params = None\n", - 
"\n", - "train_generator = MultiPageTiffGenerator(training_source,\n", - " training_target,\n", - " batch_size=batch_size,\n", - " shape=training_shape,\n", - " augment=apply_data_augmentation,\n", - " augmentations=augmentations,\n", - " deform_augment=add_elastic_deform,\n", - " deform_augmentation_params=deform_params,\n", - " val_split=validation_split_in_percent/100,\n", - " random_crop=random_crop,\n", - " downscale=downscaling_in_xy,\n", - " binary_target=binary_target)\n", - "\n", - "val_generator = MultiPageTiffGenerator(training_source,\n", - " training_target,\n", - " batch_size=batch_size,\n", - " shape=training_shape,\n", - " val_split=validation_split_in_percent/100,\n", - " is_val=True,\n", - " random_crop=random_crop,\n", - " downscale=downscaling_in_xy,\n", - " binary_target=binary_target)\n", - "\n", - "\n", - "if apply_data_augmentation:\n", - " print('Data augmentation enabled.')\n", - " sample_src_aug, sample_tgt_aug = train_generator.sample_augmentation(random.randint(0, len(train_generator)))\n", - "\n", - " def scroll_in_z(z):\n", - " f=plt.figure(figsize=(16,8))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(sample_src_aug[0,:,:,z-1,0], cmap='gray')\n", - " plt.title('Sample augmented source (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(sample_tgt_aug[0,:,:,z-1,0], cmap='magma')\n", - " plt.title('Sample training target (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " print('This is what the augmented training images will look like with the chosen settings')\n", - " interact(scroll_in_z, z=widgets.IntSlider(min=1, max=sample_src_aug.shape[3], step=1, value=0));\n", - "\n", - "else:\n", - " print('Data augmentation disabled.')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MCGklf1vZf2M" - }, - "source": [ - "# **4. Train the network**\n", - "---\n", - "\n", - "**CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training times must be less than 12 hours! If training takes longer than 12 hours, please decrease `number_of_epochs`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1KYOuygETJkT" - }, - "source": [ - "## **4.1. Show model and start training**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "lIUAOJ_LMv5E" - }, - "outputs": [], - "source": [ - "#@markdown ## Show model summary\n", - "model.summary()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "CyQI4ssarUp4" - }, - "outputs": [], - "source": [ - "#@markdown ##Start training\n", - "\n", - "#here we check that no model with the same name already exist, if so delete\n", - "if not resume_training and os.path.exists(full_model_path): \n", - " shutil.rmtree(full_model_path)\n", - " print(bcolors.WARNING+'!! 
WARNING: Folder already exists and has been overwritten !!'+bcolors.NORMAL) \n", - "\n", - "if not os.path.exists(full_model_path):\n", - " os.makedirs(full_model_path)\n", - "\n", - "pdf_export(augmentation = apply_data_augmentation, pretrained_model = resume_training)\n", - "\n", - "# Save file\n", - "params_df.to_csv(os.path.join(full_model_path, 'params.csv'))\n", - "\n", - "start = time.time()\n", - "# Start Training\n", - "model.train(epochs=number_of_epochs,\n", - " batch_size=batch_size,\n", - " train_generator=train_generator,\n", - " val_generator=val_generator,\n", - " model_path=model_path,\n", - " model_name=model_name,\n", - " loss=loss_function,\n", - " metrics=metrics,\n", - " optimizer=optimizer,\n", - " ckpt_period=checkpointing_period,\n", - " save_best_ckpt_only=save_best_only,\n", - " ckpt_path=last_ckpt_path)\n", - "\n", - "print('Training successfully completed!')\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "#Create a pdf document with training summary\n", - "\n", - "pdf_export(trained = True, augmentation = apply_data_augmentation, pretrained_model = resume_training)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0Dfn8ZsEMv5d" - }, - "source": [ - "## **4.2. Download your model from Google Drive**\n", - "\n", - "---\n", - "Once training is complete, the trained model is automatically saved to your Google Drive, in the **`model_path`** folder that was specified in Section 3. Download the folder to avoid any unwanted surprises, since the data can be erased if you train another model using the same `model_path`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "iwNmp1PUzRDQ", - "scrolled": true - }, - "outputs": [], - "source": [ - "#@markdown ##Download model directory\n", - "#@markdown 1. Specify the model_path in `model_path_download`, otherwise the model specified in Section 3.1 will be downloaded\n", - "#@markdown 2. Run this cell to zip the model directory\n", - "#@markdown 3. Download the zipped file from the *Files* tab on the left\n", - "\n", - "from google.colab import files\n", - "\n", - "model_path_download = \"\" #@param {type:\"string\"}\n", - "\n", - "if len(model_path_download) == 0:\n", - " model_path_download = full_model_path\n", - "\n", - "model_name_download = os.path.basename(model_path_download)\n", - "\n", - "print('Zipping', model_name_download)\n", - "\n", - "zip_model_path = model_name_download + '.zip'\n", - "\n", - "!zip -r \"$zip_model_path\" \"$model_path_download\"\n", - "\n", - "print('Successfully saved zipped model directory as', zip_model_path)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_0Hynw3-xHp1" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "In this section, the newly trained model can be assessed for performance. This involves inspecting the loss function in Section 5.1. 
and employing more advanced metrics in Section 5.2.\n", - "\n", - "**We highly recommend performing quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "eAJzMwPA6tlH" - }, - "outputs": [], - "source": [ - "#@markdown ###Model to be evaluated:\n", - "#@markdown If left blank, the latest model defined in Section 3 will be evaluated:\n", - "\n", - "qc_model_name = \"\" #@param {type:\"string\"}\n", - "qc_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "if len(qc_model_path) == 0 and len(qc_model_name) == 0:\n", - " qc_model_name = model_name\n", - " qc_model_path = model_path\n", - "\n", - "full_qc_model_path = os.path.join(qc_model_path, qc_model_name)\n", - "\n", - "if os.path.exists(full_qc_model_path):\n", - " print(qc_model_name + ' will be evaluated')\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dhJROwlAMv5o" - }, - "source": [ - "## **5.1. Inspecting loss function**\n", - "---\n", - "\n", - "**The training loss** is the error between prediction and target after each epoch calculated across the training data while the **validation loss** calculates the error on the (unseen) validation data. During training these values should decrease until converging at which point the model has been sufficiently trained. If the validation loss starts increasing while the training loss has plateaued, the model has overfit on the training data which reduces its ability to generalise. Aim to halt training before this point.\n", - "\n", - "**Note:** For a more in-depth explanation please refer to [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols et al.\n", - "\n", - "\n", - "The accuracy is another performance metric that is calculated after each epoch. We use the [Sørensen–Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) to score the prediction accuracy. 
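For intuition, the Dice score reported in `training_evaluation.csv` can also be computed directly with NumPy. The following is a minimal illustrative sketch (not part of the notebook code; the array shapes are arbitrary toy values), equivalent in spirit to the Keras `dice_coefficient` used during training:

```python
import numpy as np

def dice_coefficient(target, prediction, eps=1e-6):
    # Sørensen–Dice coefficient between two binary volumes (1.0 = perfect overlap)
    t = np.asarray(target, dtype=bool).ravel()
    p = np.asarray(prediction, dtype=bool).ravel()
    intersection = np.logical_and(t, p).sum()
    return 2.0 * intersection / (t.sum() + p.sum() + eps)

# Toy example: two partially overlapping binary volumes (32 slices of 64x64 pixels)
target = np.zeros((32, 64, 64), dtype=bool)
target[:, 10:40, 10:40] = True
prediction = np.zeros_like(target)
prediction[:, 15:45, 15:45] = True
print(round(dice_coefficient(target, prediction), 3))  # ~0.694
```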
\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "vMzSP50kMv5p" - }, - "outputs": [], - "source": [ - "#@markdown ##Visualise loss and accuracy\n", - "lossDataFromCSV = []\n", - "vallossDataFromCSV = []\n", - "accuracyDataFromCSV = []\n", - "valaccuracyDataFromCSV = []\n", - "\n", - "with open(full_qc_model_path + '/Quality Control/training_evaluation.csv', 'r') as csvfile:\n", - " csvRead = csv.reader(csvfile, delimiter=',')\n", - " next(csvRead)\n", - " for row in csvRead:\n", - " lossDataFromCSV.append(float(row[2]))\n", - " vallossDataFromCSV.append(float(row[4]))\n", - " accuracyDataFromCSV.append(float(row[1]))\n", - " valaccuracyDataFromCSV.append(float(row[3]))\n", - "\n", - "epochNumber = range(len(lossDataFromCSV))\n", - "plt.figure(figsize=(15,10))\n", - "\n", - "plt.subplot(2,1,1)\n", - "plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n", - "plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n", - "plt.title('Training and validation loss', fontsize=14)\n", - "plt.ylabel('Loss', fontsize=12)\n", - "plt.xlabel('Epochs', fontsize=12)\n", - "plt.legend()\n", - "\n", - "plt.subplot(2,1,2)\n", - "plt.plot(epochNumber,accuracyDataFromCSV, label='Training accuracy')\n", - "plt.plot(epochNumber,valaccuracyDataFromCSV, label='Validation accuracy')\n", - "plt.title('Training and validation accuracy', fontsize=14)\n", - "plt.ylabel('Dice', fontsize=12)\n", - "plt.xlabel('Epochs', fontsize=12)\n", - "plt.legend()\n", - "plt.savefig(full_qc_model_path + '/Quality Control/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n", - "plt.show()\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X5_92nL2xdP6" - }, - "source": [ - "## **5.2. Error mapping and quality metrics estimation**\n", - "---\n", - "This section will provide both a visual indication of the model performance by comparing the overlay of the predicted and source volume." 
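As a complement to the visual overlay generated in the next cell, a simple quantitative check can be run on the same volumes. This is an optional sketch (not part of the notebook; the file paths are placeholders for the QC target volume and the prediction written by the next cell) computing a per-slice Pearson correlation between target and prediction:

```python
import numpy as np
import tifffile

# Placeholder paths: point these at the QC target volume and the saved prediction
test_target = tifffile.imread('testing_target.tif').astype(np.float32)
test_prediction = tifffile.imread('testing_source_prediction.tif').astype(np.float32)

# Pearson correlation of each z-slice; values close to 1 indicate good agreement
per_slice_r = [np.corrcoef(t.ravel(), p.ravel())[0, 1]
               for t, p in zip(test_target, test_prediction)]
print('Median per-slice correlation:', round(float(np.median(per_slice_r)), 3))
```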
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "w90MdriMxhjD" - }, - "outputs": [], - "source": [ - "#@markdown ##Compare prediction and ground-truth on testing data\n", - "\n", - "#@markdown Provide an unseen annotated dataset to determine the performance of the model:\n", - "\n", - "testing_source = \"\" #@param{type:\"string\"}\n", - "testing_target = \"\" #@param{type:\"string\"}\n", - "\n", - "qc_dir = full_qc_model_path + '/Quality Control'\n", - "predict_dir = qc_dir + '/Prediction'\n", - "if os.path.exists(predict_dir):\n", - " shutil.rmtree(predict_dir)\n", - "\n", - "os.makedirs(predict_dir)\n", - "\n", - "# predict_dir + '/' + \n", - "predict_path = os.path.splitext(os.path.basename(testing_source))[0] + '_prediction.tif'\n", - "\n", - "def last_chars(x):\n", - " return(x[-11:])\n", - "\n", - "try:\n", - " ckpt_dir_list = glob(full_qc_model_path + '/ckpt/*')\n", - " ckpt_dir_list.sort(key=last_chars)\n", - " last_ckpt_path = ckpt_dir_list[0]\n", - " print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n", - "except IndexError:\n", - " raise CheckpointError('No previous checkpoints were found, please retrain model.')\n", - "\n", - "# Load parameters\n", - "params = pd.read_csv(os.path.join(full_qc_model_path, 'params.csv'), names=['val'], header=0, index_col=0) \n", - "\n", - "model = Unet3D(shape=params.loc['training_shape', 'val'])\n", - "\n", - "prediction = model.predict(testing_source, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - "\n", - "tifffile.imwrite(predict_path, prediction.astype('float32'), imagej=True)\n", - "\n", - "print('Predicted images!')\n", - "\n", - "qc_metrics_path = full_qc_model_path + '/Quality Control/QC_metrics_' + qc_model_name + '.csv'\n", - "\n", - "test_target = tifffile.imread(testing_target)\n", - "test_source = tifffile.imread(testing_source)\n", - "test_prediction = tifffile.imread(predict_path)\n", - "\n", - "def scroll_in_z(z):\n", - "\n", - " plt.figure(figsize=(25,5))\n", - " # Source\n", - " plt.subplot(1,4,1)\n", - " plt.axis('off')\n", - " plt.imshow(test_source[z-1], cmap='gray')\n", - " plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n", - "\n", - " # Target (Ground-truth)\n", - " plt.subplot(1,4,2)\n", - " plt.axis('off')\n", - " plt.imshow(test_target[z-1], cmap='magma')\n", - " plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n", - "\n", - " # Prediction\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " plt.imshow(test_prediction[z-1], cmap='magma')\n", - " plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n", - " \n", - " # Overlay\n", - " plt.subplot(1,4,4)\n", - " plt.axis('off')\n", - " plt.imshow(test_target[z-1], cmap='Greens')\n", - " plt.imshow(test_prediction[z-1], alpha=0.5, cmap='Purples')\n", - " plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n", - " plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n", - "interact(scroll_in_z, z=widgets.IntSlider(min=1, max=test_source.shape[0], step=1, value=0));" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aIvRxpZlsFeZ" - }, - "source": [ - "## **5.3. Determine best Intersection over Union and threshold**\n", - "---\n", - "\n", - "**Note:** This section is only relevant if the target image is a binary mask and `binary_target` is selected in Section 3! 
\n", - "\n", - "This section will provide both a visual and a quantitative indication of the model performance by comparing the overlay of the predicted and source volume, as well as computing the highest [**Intersection over Union**](https://en.wikipedia.org/wiki/Jaccard_index) (IoU) score. The IoU is also known as the Jaccard Index. \n", - "\n", - "The best threshold is calculated using the IoU. Each threshold value from 0 to 255 is tested and the threshold with the highest score is deemed the best. The IoU is calculated for the entire volume in 3D." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "XhkeZTFusHA8" - }, - "outputs": [], - "source": [ - "\n", - "#@markdown ##Calculate Intersection over Union and best threshold \n", - "prediction = tifffile.imread(predict_path)\n", - "prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - "\n", - "target = tifffile.imread(testing_target).astype(np.bool)\n", - "\n", - "def iou_vs_threshold(prediction, target):\n", - " threshold_list = []\n", - " IoU_scores_list = []\n", - "\n", - " for threshold in range(0,256): \n", - " mask = prediction > threshold\n", - "\n", - " intersection = np.logical_and(target, mask)\n", - " union = np.logical_or(target, mask)\n", - " iou_score = np.sum(intersection) / np.sum(union)\n", - "\n", - " threshold_list.append(threshold)\n", - " IoU_scores_list.append(iou_score)\n", - "\n", - " return threshold_list, IoU_scores_list\n", - "\n", - "threshold_list, IoU_scores_list = iou_vs_threshold(prediction, target)\n", - "thresh_arr = np.array(list(zip(threshold_list, IoU_scores_list)))\n", - "best_thresh = int(np.where(thresh_arr == np.max(thresh_arr[:,1]))[0])\n", - "best_iou = IoU_scores_list[best_thresh]\n", - "\n", - "print('Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh))\n", - "\n", - "def adjust_threshold(threshold, z):\n", - "\n", - " f=plt.figure(figsize=(25,5))\n", - " plt.subplot(1,4,1)\n", - " plt.imshow((prediction[z-1] > threshold).astype('uint8'), cmap='magma')\n", - " plt.title('Prediction (Threshold = ' + str(threshold) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,4,2)\n", - " plt.imshow(target[z-1], cmap='magma')\n", - " plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,4,3)\n", - " plt.axis('off')\n", - " plt.imshow(test_source[z-1], cmap='gray')\n", - " plt.imshow((prediction[z-1] > threshold).astype('uint8'), alpha=0.4, cmap='Reds')\n", - " plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n", - "\n", - " plt.subplot(1,4,4)\n", - " plt.title('Threshold vs. IoU', fontsize=15)\n", - " plt.plot(threshold_list, IoU_scores_list)\n", - " plt.plot(threshold, IoU_scores_list[threshold], 'ro') \n", - " plt.ylabel('IoU score')\n", - " plt.xlabel('Threshold')\n", - " plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png',bbox_inches=matplotlib.transforms.Bbox([[17.5,0],[23,5]]),pad_inches=0)\n", - " plt.show()\n", - "\n", - "interact(adjust_threshold, \n", - " threshold=widgets.IntSlider(min=0, max=255, step=1, value=best_thresh),\n", - " z=widgets.IntSlider(min=1, max=prediction.shape[0], step=1, value=0));\n", - "\n", - "#Make a pdf summary of the QC results\n", - "\n", - "qc_pdf_export()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-tJeeJjLnRkP" - }, - "source": [ - "# **6. 
Using the trained model**\n", - "\n", - "---\n", - "\n", - "Once sufficient performance of the trained model has been established using Section 5, the network can be used to segment unseen volumetric data." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d8wuQGjoq6eN" - }, - "source": [ - "## **6.1. Generate predictions from unseen dataset**\n", - "---\n", - "\n", - "The most recently trained model can now be used to predict segmentation masks on unseen images. If you want to use an older model, provide its path in `model_path`; if it is left blank, the most recently trained model will be used. Predicted output images are saved in `output_path` as ImageJ-compatible TIFF files.\n", - "\n", - "## **Prediction parameters**\n", - "\n", - "* **`source_path`** specifies the location of the source \n", - "image volume.\n", - "\n", - "* **`output_directory`** specifies the directory where the output predictions are stored.\n", - "\n", - "* **`binary_target`** should be chosen if the network is trained to predict binary segmentation masks.\n", - "\n", - "* **`threshold`** can be calculated in Section 5 and is used to generate binary masks from the predictions.\n", - "\n", - "* **`big_tiff`** should be chosen if the expected prediction exceeds 4GB. The predictions will be saved using the BigTIFF format. Beware that this might substantially reduce the prediction speed. *Default: False* \n", - "\n", - "* **`prediction_depth`** is only relevant if the prediction is saved as a BigTIFF. The prediction will not be performed in one go, so as not to deplete the memory resources. Instead, the prediction is iteratively performed on a subset of the entire volume with shape `(source.shape[0], source.shape[1], prediction_depth)`. *Default: 32* (see the sketch below this list)\n", - "\n", - "* **`model_path`** specifies the path to a model other than the most recently trained."
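To make the roles of `big_tiff` and `prediction_depth` concrete, here is a minimal sketch of slab-wise prediction and BigTIFF writing. It is not the notebook's own cell: `predict_slab` is a stand-in for the actual `model.predict(..., z_range=...)` call made further below, and the file names simply reuse the placeholder names from the directory structure in Section 0.

```python
import numpy as np
import tifffile

prediction_depth = 32  # number of z-slices handled per iteration
src = tifffile.imread('data_to_predict_on.tif')  # placeholder input volume

def predict_slab(volume_slab):
    # Stand-in for the network inference on one slab (the real notebook calls
    # model.predict(..., z_range=...) here); returning the slab keeps the sketch runnable.
    return volume_slab.astype('float32')

# Write the result slab by slab so the full prediction never has to fit in memory
with tifffile.TiffWriter('prediction_results.tif', bigtiff=True) as tif:
    for z in range(0, src.shape[0], prediction_depth):
        slab = predict_slab(src[z:z + prediction_depth])
        slab = np.interp(slab, (slab.min(), slab.max()), (0, 255)).astype('float32')
        for plane in slab:
            tif.save(plane)  # append each z-slice to the BigTIFF output
```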
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "DEmhPh5fsWX2" - }, - "outputs": [], - "source": [ - "#@markdown ## Download example volume\n", - "\n", - "#@markdown This can take up to an hour\n", - "\n", - "import requests \n", - "import os\n", - "from tqdm.notebook import tqdm \n", - "\n", - "\n", - "def download_from_url(url, save_as):\n", - " file_url = url\n", - " r = requests.get(file_url, stream=True) \n", - " \n", - " with open(save_as, 'wb') as file: \n", - " for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=3275073, ncols=1000):\n", - " if block:\n", - " file.write(block) \n", - "\n", - "download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/volumedata.tif', 'example_dataset/volumedata.tif')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "y2TD5p7MZrEb" - }, - "outputs": [], - "source": [ - "#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then run the cell to predict outputs from your unseen images.\n", - "\n", - "source_path = \"\" #@param {type:\"string\"}\n", - "output_directory = \"\" #@param {type:\"string\"}\n", - "\n", - "if not os.path.exists(output_directory):\n", - " os.makedirs(output_directory)\n", - "\n", - "output_path = os.path.join(output_directory, os.path.splitext(os.path.basename(source_path))[0] + '_predicted.tif')\n", - "#@markdown ###Prediction parameters:\n", - "\n", - "binary_target = True #@param {type:\"boolean\"}\n", - "\n", - "save_probability_map = False #@param {type:\"boolean\"}\n", - "\n", - "#@markdown Determine best threshold in Section 5.2.\n", - "\n", - "use_calculated_threshold = True #@param {type:\"boolean\"}\n", - "threshold = 200#@param {type:\"number\"}\n", - "\n", - "# Tifffile library issues means that images cannot be appended to \n", - "#@markdown Choose if prediction file exceeds 4GB or if input file is very large (above 2GB). Image volume saved as BigTIFF.\n", - "big_tiff = False #@param {type:\"boolean\"}\n", - "\n", - "#@markdown Reduce `prediction_depth` if runtime runs out of memory during prediction. 
Only relevant if prediction saved as BigTIFF\n", - "\n", - "prediction_depth = 32#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Model to be evaluated\n", - "#@markdown If left blank, the latest model defined in Section 5 will be evaluated\n", - "\n", - "full_model_path_ = \"\" #@param {type:\"string\"}\n", - "\n", - "if len(full_model_path_) == 0:\n", - " full_model_path_ = os.path.join(qc_model_path, qc_model_name) \n", - "\n", - "\n", - "\n", - "# Load parameters\n", - "params = pd.read_csv(os.path.join(full_model_path_, 'params.csv'), names=['val'], header=0, index_col=0) \n", - "model = Unet3D(shape=params.loc['training_shape', 'val'])\n", - "\n", - "if use_calculated_threshold:\n", - " threshold = best_thresh\n", - "\n", - "def last_chars(x):\n", - " return(x[-11:])\n", - "\n", - "try:\n", - " ckpt_dir_list = glob(full_model_path_ + '/ckpt/*')\n", - " ckpt_dir_list.sort(key=last_chars)\n", - " last_ckpt_path = ckpt_dir_list[0]\n", - " print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n", - "except IndexError:\n", - " raise CheckpointError('No previous checkpoints were found, please retrain model.')\n", - "\n", - "src = tifffile.imread(source_path)\n", - "\n", - "if src.nbytes >= 4e9:\n", - " big_tiff = True\n", - " print('The source file exceeds 4GB in memory, prediction will be saved as BigTIFF!')\n", - "\n", - "if binary_target:\n", - " if not big_tiff:\n", - " prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " prediction = (prediction > threshold).astype('float32')\n", - "\n", - " tifffile.imwrite(output_path, prediction, imagej=True)\n", - "\n", - " else:\n", - " with tifffile.TiffWriter(output_path, bigtiff=True) as tif:\n", - " for i in tqdm(range(0, src.shape[0], prediction_depth)):\n", - " prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " prediction = (prediction > threshold).astype('float32')\n", - " \n", - " for j in range(prediction.shape[0]):\n", - " tif.save(prediction[j])\n", - "\n", - "if not binary_target or save_probability_map:\n", - " if not binary_target:\n", - " prob_map_path = output_path\n", - " else:\n", - " prob_map_path = os.path.splitext(output_path)[0] + '_prob_map.tif'\n", - " \n", - " if not big_tiff:\n", - " prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " tifffile.imwrite(prob_map_path, prediction.astype('float32'), imagej=True)\n", - "\n", - " else:\n", - " with tifffile.TiffWriter(prob_map_path, bigtiff=True) as tif:\n", - " for i in tqdm(range(0, src.shape[0], prediction_depth)):\n", - " prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n", - " prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n", - " \n", - " for j in range(prediction.shape[0]):\n", - " tif.save(prediction[j])\n", - "\n", - "print('Predictions saved as', output_path)\n", - "\n", - 
"src_volume = tifffile.imread(source_path)\n", - "pred_volume = tifffile.imread(output_path)\n", - "\n", - "def scroll_in_z(z):\n", - " \n", - " f=plt.figure(figsize=(25,5))\n", - " plt.subplot(1,2,1)\n", - " plt.imshow(src_volume[z-1], cmap='gray')\n", - " plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - " plt.subplot(1,2,2)\n", - " plt.imshow(pred_volume[z-1], cmap='magma')\n", - " plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n", - " plt.axis('off')\n", - "\n", - "interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_volume.shape[0], step=1, value=0));\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.2. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UvSlTaH14s3t" - }, - "source": [ - "\n", - "#**Thank you for using 3D U-Net!**" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "machine_shape": "hm", - "name": "U-Net_3D_ZeroCostDL4Mic.ipynb", - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"U-Net_3D_ZeroCostDL4Mic.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **U-Net (3D)**\n"," ---\n","\n"," The 3D U-Net was first introduced by [Çiçek et al](https://arxiv.org/abs/1606.06650) for learning dense volumetric segmentations from sparsely annotated ground-truth data building upon the original U-Net architecture by [Ronneberger et al](https://arxiv.org/abs/1505.04597). \n","\n","**This particular implementation allows supervised learning between any two types of 3D image data. 
If you are interested in image segmentation of 2D datasets, you should use the 2D U-Net notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project ([ZeroCostDL4Mic](/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki)) jointly developed by the [Jacquemet](https://cellmig.org/) and [Henriques](https://henriqueslab.github.io/) laboratories and created by Daniel Krentzel.\n","\n","This notebook is largely based on the following paper: \n","\n","[**3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation**](https://arxiv.org/pdf/1606.06650.pdf) by Özgün Çiçek *et al.* published on arXiv in 2016\n","\n","The following three Python libraries play an important role in the notebook: \n","\n","1. [**Elasticdeform**](/~https://github.com/gvtulder/elasticdeform)\n"," by Gijs van Tulder was used to augment the 3D training data using elastic grid-based deformations as described in the original 3D U-Net paper. \n","\n","2. [**Tifffile**](/~https://github.com/cgohlke/tifffile) by Christoph Gohlke is a great library for reading and writing TIFF files. \n","\n","3. [**Imgaug**](/~https://github.com/aleju/imgaug) by Alexander Jung *et al.* is an amazing library for image augmentation in machine learning - it is the most complete and extensive image augmentation package I have found to date. \n","\n","The [example dataset](https://www.epfl.ch/labs/cvlab/data/data-em/) represents a 5x5x5µm section taken from the CA1 hippocampus region of the brain with annotated mitochondria and was acquired by Graham Knott and Marco Cantoni at EPFL.\n","\n","\n","**Please also cite the original paper and relevant Python libraries when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Videos describing how to use ZeroCostDL4Mic notebooks are available on YouTube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cells: \n","\n","**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new one by clicking `+ Text`.\n","\n","**Code cells** contain code which can be modified by selecting the cell. To execute the cell, move your cursor to the `[]`-symbol on the left side of the cell (a play button should appear). Click it to execute the cell. Once the cell is fully executed, the animation stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","Three tabs are located on the upper left side of the notebook:\n","\n","1. *Table of contents* contains the structure of the notebook. Click the headers to move quickly between sections.\n","\n","2. *Code snippets* provides a wide array of example code specific to Google Colab. You can ignore this when using this notebook.\n","\n","3. *Files* displays the current working directory. We will mount your Google Drive in Section 1.2. so that you can access your files and save them permanently.\n","\n","**Important:** All uploaded files are purged once the runtime ends.\n","\n","**Note:** The directory *sample data* in *Files* contains default files. Do not upload anything there!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive by clicking *File* -> *Save a copy in Drive*.\n","\n","To **edit a cell**, double-click on the text. This will either display the source code (in code cells) or the [markdown](https://colab.research.google.com/notebooks/markdown_guide.ipynb#scrollTo=70pYkR9LiOV0) (in text cells).\n","You can use `#` in code cells to comment out parts of the code. This allows you to keep the original piece of code while not executing it."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n","\n","As the network operates in three dimensions, careful consideration should be given to correctly pre-processing the data. Ensure that the structure of interest does not substantially change between slices - image volumes with isotropic pixel sizes are ideal for this architecture.\n","\n","Each image volume must be provided as an **8-bit** or **binary multipage TIFF file** to maintain the correct ordering of individual image slices. If more than one image volume has been annotated, source and target files must be named identically and placed in separate directories. In case only one image volume has been annotated, source and target files do not have to be placed in separate directories and can be named differently, as long as their paths are explicitly provided in Section 3. \n","\n","**Prepare two datasets** (*training* and *testing*) for quality control purposes. Make sure that the *testing* dataset does not overlap with the *training* dataset and is ideally sourced from a different acquisition and sample to ensure robustness of the trained model. \n","\n","\n","---\n","\n","\n","### **Directory structure**\n","\n","Make sure to adhere to one of the following directory structures. If only one annotated training volume exists, choose the first structure.
In case more than one training volume is available, choose the second structure.\n","\n","**Structure 1:** Only one training volume\n","```\n","path/to/directory/with/one/training/volume\n","│--training_source.tif\n","│--training_target.tif\n","| \n","│--testing_source.tif\n","|--testing_target.tif \n","|\n","|--data_to_predict_on.tif\n","|--prediction_results.tif\n","\n","```\n","**Structure 2:** Various training volumes\n","```\n","path/to/directory/with/various/training/volumes\n","│--testing_source.tif\n","|--testing_target.tif \n","|\n","└───training\n","| └───source\n","| | |--training_volume_one.tif\n","| | |--training_volume_two.tif\n","| | |--...\n","| | |--training_volume_n.tif\n","| |\n","| └───target\n","| |--training_volume_one.tif\n","| |--training_volume_two.tif\n","| |--...\n","| |--training_volume_n.tif\n","|\n","|--data_to_predict_on.tif\n","|--prediction_results.tif\n","```\n","**Note:** Naming directories is completely up to you, as long as the paths are correctly specified throughout the notebook.\n","\n","\n","---\n","\n","\n","### **Important note**\n","\n","* If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do so), you will need to run **Sections 1 - 4**, then use **Section 5** to assess the quality of your model and **Section 6** to run predictions using the model that you trained.\n","\n","* If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 5** to assess the quality of your model.\n","\n","* If you only wish to **Run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"code","metadata":{"id":"M-GZMaL7pd8a","cellView":"form"},"source":["#@markdown ##**Download example dataset**\n","\n","#@markdown This usually takes a few minutes. 
The images are saved in *example_dataset*.\n","\n","import requests \n","import os\n","from tqdm.notebook import tqdm \n","\n","def make_directory(dir):\n"," if not os.path.exists(dir):\n"," os.makedirs(dir)\n","\n","def download_from_url(url, save_as):\n"," file_url = url\n"," r = requests.get(file_url, stream=True) \n"," \n"," with open(save_as, 'wb') as file: \n"," for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=126875, ncols=1000):\n"," if block:\n"," file.write(block) \n","\n","\n","make_directory('example_dataset')\n","\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training.tif', 'example_dataset/training.tif')\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training_groundtruth.tif', 'example_dataset/training_groundtruth.tif')\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing.tif', 'example_dataset/testing.tif')\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing_groundtruth.tif', 'example_dataset/testing_groundtruth.tif')\n","\n","print('Example dataset successfully downloaded!')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"zxELU7CIp4oF","cellView":"form"},"source":["#@markdown ##Unzip pre-trained model directory\n","\n","#@markdown 1. Upload a zipped model directory using the *Files* tab\n","#@markdown 2. Run this cell to unzip your model file\n","#@markdown 3. The model directory will appear in the *Files* tab \n","\n","from google.colab import files\n","\n","zipped_model_file = \"\" #@param {type:\"string\"}\n","\n","!unzip \"$zipped_model_file\""],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. 
Install 3D U-Net dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["#@markdown ##Install dependencies and instantiate network\n","Notebook_version = '1.13'\n","Network = 'U-Net (3D)'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#Put the imported code and libraries here\n","!pip install fpdf\n","from __future__ import absolute_import, division, print_function, unicode_literals\n","\n","try:\n"," import elasticdeform\n","except:\n"," !pip install elasticdeform\n"," import elasticdeform\n","\n","try:\n"," import tifffile\n","except:\n"," !pip install tifffile\n"," import tifffile\n","\n","try:\n"," import imgaug.augmenters as iaa\n","except:\n"," !pip install imgaug\n"," import imgaug.augmenters as iaa\n","\n","import os\n","import csv\n","import random\n","import h5py\n","import imageio\n","import math\n","import shutil\n","\n","import pandas as pd\n","from glob import glob\n","from tqdm import tqdm\n","\n","from skimage import transform\n","from skimage import exposure\n","from skimage import color\n","from skimage import io\n","\n","from scipy.ndimage import zoom\n","\n","import matplotlib.pyplot as plt\n","\n","import numpy as np\n","import tensorflow as tf\n","\n","# from keras import backend as K\n","\n","# from keras.layers import Conv3D\n","# from keras.layers import BatchNormalization\n","# from keras.layers import ReLU\n","# from keras.layers import MaxPooling3D\n","# from keras.layers import Conv3DTranspose\n","# from keras.layers import Input\n","# from keras.layers import Concatenate\n","\n","# from keras.models import Model\n","\n","# from keras.utils import Sequence\n","# from keras.callbacks import ModelCheckpoint\n","# from keras.callbacks import CSVLogger\n","# from keras.callbacks import Callback\n","\n","from tensorflow.keras import backend as K\n","\n","from tensorflow.keras.layers import Conv3D\n","from tensorflow.keras.layers import BatchNormalization\n","from tensorflow.keras.layers import ReLU\n","from tensorflow.keras.layers import MaxPooling3D\n","from tensorflow.keras.layers import 
Conv3DTranspose\n","from tensorflow.keras.layers import Input\n","from tensorflow.keras.layers import Concatenate\n","\n","from tensorflow.keras.models import Model\n","\n","from tensorflow.keras.utils import Sequence\n","from tensorflow.keras.callbacks import ModelCheckpoint\n","from tensorflow.keras.callbacks import CSVLogger\n","from tensorflow.keras.callbacks import Callback\n","\n","from tensorflow.keras.metrics import RootMeanSquaredError\n","\n","from tensorflow.keras.optimizers import Adam, SGD, RMSprop\n","\n","from ipywidgets import interact\n","from ipywidgets import interactive\n","from ipywidgets import fixed\n","from ipywidgets import interact_manual \n","import ipywidgets as widgets\n","\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","import time\n","\n","from skimage import io\n","import matplotlib\n","\n","print(\"Dependencies installed and imported.\")\n","\n","# Define MultiPageTiffGenerator class\n","class MultiPageTiffGenerator(Sequence):\n","\n"," def __init__(self,\n"," source_path,\n"," target_path,\n"," batch_size=1,\n"," shape=(128,128,32,1),\n"," augment=False,\n"," augmentations=[],\n"," deform_augment=False,\n"," deform_augmentation_params=(5,3,4),\n"," val_split=0.2,\n"," is_val=False,\n"," random_crop=True,\n"," downscale=1,\n"," binary_target=False):\n","\n"," # If directory with various multi-page tiffiles is provided read as list\n"," if os.path.isfile(source_path):\n"," self.dir_flag = False\n"," self.source = tifffile.imread(source_path)\n"," if binary_target:\n"," self.target = tifffile.imread(target_path).astype(np.bool)\n"," else:\n"," self.target = tifffile.imread(target_path)\n","\n"," elif os.path.isdir(source_path):\n"," self.dir_flag = True\n"," self.source_dir_list = glob(os.path.join(source_path, '*'))\n"," self.target_dir_list = glob(os.path.join(target_path, '*'))\n","\n"," self.source_dir_list.sort()\n"," self.target_dir_list.sort()\n","\n"," self.shape = shape\n"," self.batch_size = batch_size\n"," self.augment = augment\n"," self.val_split = val_split\n"," self.is_val = is_val\n"," self.random_crop = random_crop\n"," self.downscale = downscale\n"," self.binary_target = binary_target\n"," self.deform_augment = deform_augment\n"," self.on_epoch_end()\n"," \n"," if self.augment:\n"," # pass list of augmentation functions \n"," self.seq = iaa.Sequential(augmentations, random_order=True) # apply augmenters in random order\n"," if self.deform_augment:\n"," self.deform_sigma, self.deform_points, self.deform_order = deform_augmentation_params\n","\n"," def __len__(self):\n"," # If various multi-page tiff files provided sum all images within each\n"," if self.augment:\n"," augment_factor = 4\n"," else:\n"," augment_factor = 1\n"," \n"," if self.dir_flag:\n"," num_of_imgs = 0\n"," for tiff_path in self.source_dir_list:\n"," num_of_imgs += tifffile.imread(tiff_path).shape[0]\n"," xy_shape = tifffile.imread(self.source_dir_list[0]).shape[1:]\n","\n"," if self.is_val:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = xy_shape[0] * xy_shape[1] * self.val_split * num_of_imgs\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n"," else:\n"," return math.floor(self.val_split * num_of_imgs / self.batch_size)\n"," else:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = xy_shape[0] * xy_shape[1] * (1 - 
self.val_split) * num_of_imgs\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n","\n"," else:\n"," return math.floor(augment_factor*(1 - self.val_split) * num_of_imgs/self.batch_size)\n"," else:\n"," if self.is_val:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = self.source.shape[0] * self.source.shape[1] * self.val_split * self.source.shape[2]\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n"," else:\n"," return math.floor((self.val_split * self.source.shape[0] / self.batch_size))\n"," else:\n"," if self.random_crop:\n"," crop_volume = self.shape[0] * self.shape[1] * self.shape[2]\n"," volume = self.source.shape[0] * self.source.shape[1] * (1 - self.val_split) * self.source.shape[2]\n"," return math.floor(augment_factor * volume / (crop_volume * self.batch_size * self.downscale))\n"," else:\n"," return math.floor(augment_factor * (1 - self.val_split) * self.source.shape[0] / self.batch_size)\n","\n"," def __getitem__(self, idx):\n"," source_batch = np.empty((self.batch_size,\n"," self.shape[0],\n"," self.shape[1],\n"," self.shape[2],\n"," self.shape[3]))\n"," target_batch = np.empty((self.batch_size,\n"," self.shape[0],\n"," self.shape[1],\n"," self.shape[2],\n"," self.shape[3]))\n","\n"," for batch in range(self.batch_size):\n"," # Modulo operator ensures IndexError is avoided\n"," stack_start = self.batch_list[(idx+batch*self.shape[2])%len(self.batch_list)]\n","\n"," if self.dir_flag:\n"," self.source = tifffile.imread(self.source_dir_list[stack_start[0]])\n"," if self.binary_target:\n"," self.target = tifffile.imread(self.target_dir_list[stack_start[0]]).astype(np.bool)\n"," else:\n"," self.target = tifffile.imread(self.target_dir_list[stack_start[0]])\n","\n"," src_list = []\n"," tgt_list = []\n"," for i in range(stack_start[1], stack_start[1]+self.shape[2]):\n"," src = self.source[i]\n"," src = transform.downscale_local_mean(src, (self.downscale, self.downscale))\n"," if not self.random_crop:\n"," src = transform.resize(src, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n"," src = self._min_max_scaling(src)\n"," src_list.append(src)\n","\n"," tgt = self.target[i]\n"," tgt = transform.downscale_local_mean(tgt, (self.downscale, self.downscale))\n"," if not self.random_crop:\n"," tgt = transform.resize(tgt, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)\n"," if not self.binary_target:\n"," tgt = self._min_max_scaling(tgt)\n"," tgt_list.append(tgt)\n","\n"," if self.random_crop:\n"," if src.shape[0] == self.shape[0]:\n"," x_rand = 0\n"," if src.shape[1] == self.shape[1]:\n"," y_rand = 0\n"," if src.shape[0] > self.shape[0]:\n"," x_rand = np.random.randint(src.shape[0] - self.shape[0])\n"," if src.shape[1] > self.shape[1]:\n"," y_rand = np.random.randint(src.shape[1] - self.shape[1])\n"," if src.shape[0] < self.shape[0] or src.shape[1] < self.shape[1]:\n"," raise ValueError('Patch shape larger than (downscaled) source shape')\n"," \n"," for i in range(self.shape[2]):\n"," if self.random_crop:\n"," src = src_list[i]\n"," tgt = tgt_list[i]\n"," src_crop = src[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n"," tgt_crop = tgt[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]\n"," else:\n"," src_crop = src_list[i]\n"," tgt_crop = tgt_list[i]\n","\n"," source_batch[batch,:,:,i,0] = src_crop\n"," target_batch[batch,:,:,i,0] = tgt_crop\n","\n"," if self.augment:\n"," # 
On-the-fly data augmentation\n"," source_batch, target_batch = self.augment_volume(source_batch, target_batch)\n","\n"," # Data augmentation by reversing stack\n"," if np.random.random() > 0.5:\n"," source_batch, target_batch = source_batch[::-1], target_batch[::-1]\n"," \n"," # Data augmentation by elastic deformation\n"," if np.random.random() > 0.5 and self.deform_augment:\n"," source_batch, target_batch = self.deform_volume(source_batch, target_batch)\n"," \n"," if not self.binary_target:\n"," target_batch = self._min_max_scaling(target_batch)\n"," \n"," return self._min_max_scaling(source_batch), target_batch\n"," \n"," else:\n"," return source_batch, target_batch\n","\n"," def on_epoch_end(self):\n"," # Validation split performed here\n"," self.batch_list = []\n"," # Create batch_list of all combinations of tifffile and stack position\n"," if self.dir_flag:\n"," for i in range(len(self.source_dir_list)):\n"," num_of_pages = tifffile.imread(self.source_dir_list[i]).shape[0]\n"," if self.is_val:\n"," start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n"," for j in range(start_page, num_of_pages-self.shape[2]):\n"," self.batch_list.append([i, j])\n"," else:\n"," last_page = math.floor((1-self.val_split)*num_of_pages)\n"," for j in range(last_page-self.shape[2]):\n"," self.batch_list.append([i, j])\n"," else:\n"," num_of_pages = self.source.shape[0]\n"," if self.is_val:\n"," start_page = num_of_pages-math.floor(self.val_split*num_of_pages)\n"," for j in range(start_page, num_of_pages-self.shape[2]):\n"," self.batch_list.append([0, j])\n","\n"," else:\n"," last_page = math.floor((1-self.val_split)*num_of_pages)\n"," for j in range(last_page-self.shape[2]):\n"," self.batch_list.append([0, j])\n"," \n"," if self.is_val and (len(self.batch_list) <= 0):\n"," raise ValueError('validation_split too small! 
Increase val_split or decrease z-depth')\n"," random.shuffle(self.batch_list)\n"," \n"," def _min_max_scaling(self, data):\n"," n = data - np.min(data)\n"," d = np.max(data) - np.min(data) \n"," \n"," return n/d\n"," \n"," def class_weights(self):\n"," ones = 0\n"," pixels = 0\n","\n"," if self.dir_flag:\n"," for i in range(len(self.target_dir_list)):\n"," tgt = tifffile.imread(self.target_dir_list[i]).astype(np.bool)\n"," ones += np.sum(tgt)\n"," pixels += tgt.shape[0]*tgt.shape[1]*tgt.shape[2]\n"," else:\n"," ones = np.sum(self.target)\n"," pixels = self.target.shape[0]*self.target.shape[1]*self.target.shape[2]\n"," p_ones = ones/pixels\n"," p_zeros = 1-p_ones\n","\n"," # Return swapped probability to increase weight of unlikely class\n"," return p_ones, p_zeros\n","\n"," def deform_volume(self, src_vol, tgt_vol):\n"," [src_dfrm, tgt_dfrm] = elasticdeform.deform_random_grid([src_vol, tgt_vol],\n"," axis=(1, 2, 3),\n"," sigma=self.deform_sigma,\n"," points=self.deform_points,\n"," order=self.deform_order)\n"," if self.binary_target:\n"," tgt_dfrm = tgt_dfrm > 0.1\n"," \n"," return self._min_max_scaling(src_dfrm), tgt_dfrm \n","\n"," def augment_volume(self, src_vol, tgt_vol):\n"," src_vol_aug = np.empty(src_vol.shape)\n"," tgt_vol_aug = np.empty(tgt_vol.shape)\n","\n"," for i in range(src_vol.shape[3]):\n"," src_vol_aug[:,:,:,i,0], tgt_vol_aug[:,:,:,i,0] = self.seq(images=src_vol[:,:,:,i,0].astype('float16'), \n"," segmentation_maps=tgt_vol[:,:,:,i,0].astype(bool))\n"," return self._min_max_scaling(src_vol_aug), tgt_vol_aug\n","\n"," def sample_augmentation(self, idx):\n"," src, tgt = self.__getitem__(idx)\n","\n"," src_aug, tgt_aug = self.augment_volume(src, tgt)\n"," \n"," if self.deform_augment:\n"," src_aug, tgt_aug = self.deform_volume(src_aug, tgt_aug)\n","\n"," return src_aug, tgt_aug \n","\n","# Define custom loss and dice coefficient\n","def dice_coefficient(y_true, y_pred):\n"," eps = 1e-6\n"," y_true_f = K.flatten(y_true)\n"," y_pred_f = K.flatten(y_pred)\n"," intersection = K.sum(y_true_f*y_pred_f)\n","\n"," return (2.*intersection)/(K.sum(y_true_f*y_true_f)+K.sum(y_pred_f*y_pred_f)+eps)\n","\n","def weighted_binary_crossentropy(zero_weight, one_weight):\n"," def _weighted_binary_crossentropy(y_true, y_pred):\n"," binary_crossentropy = K.binary_crossentropy(y_true, y_pred)\n","\n"," weight_vector = y_true*one_weight+(1.-y_true)*zero_weight\n"," weighted_binary_crossentropy = weight_vector*binary_crossentropy\n","\n"," return K.mean(weighted_binary_crossentropy)\n","\n"," return _weighted_binary_crossentropy\n","\n","# Custom callback showing sample prediction\n","class SampleImageCallback(Callback):\n","\n"," def __init__(self, model, sample_data, model_path, save=False):\n"," self.model = model\n"," self.sample_data = sample_data\n"," self.model_path = model_path\n"," self.save = save\n","\n"," def on_epoch_end(self, epoch, logs={}):\n"," sample_predict = self.model.predict_on_batch(self.sample_data)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(self.sample_data[0,:,:,0,0], interpolation='nearest', cmap='gray')\n"," plt.title('Sample source')\n"," plt.axis('off');\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(sample_predict[0,:,:,0,0], interpolation='nearest', cmap='magma')\n"," plt.title('Predicted target')\n"," plt.axis('off');\n","\n"," plt.show()\n","\n"," if self.save:\n"," plt.savefig(self.model_path + '/epoch_' + str(epoch+1) + '.png')\n","\n","\n","# Define Unet3D class\n","class Unet3D:\n","\n"," def __init__(self,\n"," 
shape=(256,256,16,1)):\n"," if isinstance(shape, str):\n"," shape = eval(shape)\n","\n"," self.shape = shape\n"," \n"," input_tensor = Input(self.shape, name='input')\n","\n"," self.model = self.unet_3D(input_tensor)\n","\n"," def down_block_3D(self, input_tensor, filters):\n"," x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(input_tensor)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," return x\n","\n"," def up_block_3D(self, input_tensor, concat_layer, filters):\n"," x = Conv3DTranspose(filters, kernel_size=(2,2,2), strides=(2,2,2))(input_tensor)\n","\n"," x = Concatenate()([x, concat_layer])\n","\n"," x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(x)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)\n"," x = BatchNormalization()(x)\n"," x = ReLU()(x)\n","\n"," return x\n","\n"," def unet_3D(self, input_tensor, filters=32):\n"," d1 = self.down_block_3D(input_tensor, filters=filters)\n"," p1 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d1)\n"," d2 = self.down_block_3D(p1, filters=filters*2)\n"," p2 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d2)\n"," d3 = self.down_block_3D(p2, filters=filters*4)\n"," p3 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d3)\n","\n"," d4 = self.down_block_3D(p3, filters=filters*8)\n","\n"," u1 = self.up_block_3D(d4, d3, filters=filters*4)\n"," u2 = self.up_block_3D(u1, d2, filters=filters*2)\n"," u3 = self.up_block_3D(u2, d1, filters=filters)\n","\n"," output_tensor = Conv3D(filters=1, kernel_size=(1,1,1), activation='sigmoid')(u3)\n","\n"," return Model(inputs=[input_tensor], outputs=[output_tensor])\n","\n"," def summary(self):\n"," return self.model.summary()\n","\n"," # Pass generators instead\n"," def train(self, \n"," epochs, \n"," batch_size, \n"," train_generator,\n"," val_generator, \n"," model_path, \n"," model_name,\n"," optimizer='adam',\n"," learning_rate=0.001,\n"," loss='weighted_binary_crossentropy',\n"," metrics='dice',\n"," ckpt_period=1, \n"," save_best_ckpt_only=False, \n"," ckpt_path=None):\n","\n"," class_weight_zero, class_weight_one = train_generator.class_weights()\n"," \n"," if loss == 'weighted_binary_crossentropy':\n"," loss = weighted_binary_crossentropy(class_weight_zero, class_weight_one)\n"," \n"," if metrics == 'dice':\n"," metrics = dice_coefficient\n","\n"," if optimizer == 'adam':\n"," optimizer = Adam(learning_rate=learning_rate)\n"," elif optimizer == 'sgd':\n"," optimizer = SGD(learning_rate=learning_rate)\n"," elif optimizer == 'rmsprop':\n"," optimizer = RMSprop(learning_rate=learning_rate)\n","\n"," self.model.compile(optimizer=optimizer,\n"," loss=loss,\n"," metrics=[metrics])\n","\n"," if ckpt_path is not None:\n"," self.model.load_weights(ckpt_path)\n","\n"," full_model_path = os.path.join(model_path, model_name)\n","\n"," if not os.path.exists(full_model_path):\n"," os.makedirs(full_model_path)\n"," \n"," log_dir = full_model_path + '/Quality Control'\n","\n"," if not os.path.exists(log_dir):\n"," os.makedirs(log_dir)\n"," \n"," ckpt_dir = full_model_path + '/ckpt'\n","\n"," if not os.path.exists(ckpt_dir):\n"," os.makedirs(ckpt_dir)\n","\n"," csv_out_name = log_dir + '/training_evaluation.csv'\n"," if ckpt_path is None:\n"," csv_logger = 
CSVLogger(csv_out_name)\n"," else:\n"," csv_logger = CSVLogger(csv_out_name, append=True)\n","\n"," if save_best_ckpt_only:\n"," ckpt_name = ckpt_dir + '/' + model_name + '.hdf5'\n"," else:\n"," ckpt_name = ckpt_dir + '/' + model_name + '_epoch_{epoch:02d}_val_loss_{val_loss:.4f}.hdf5'\n"," \n"," model_ckpt = ModelCheckpoint(ckpt_name,\n"," verbose=1,\n"," period=ckpt_period,\n"," save_best_only=save_best_ckpt_only,\n"," save_weights_only=True)\n","\n"," sample_batch, __ = val_generator.__getitem__(random.randint(0, len(val_generator)))\n"," sample_img = SampleImageCallback(self.model, \n"," sample_batch, \n"," model_path)\n","\n"," self.model.fit_generator(generator=train_generator,\n"," validation_data=val_generator,\n"," validation_steps=math.floor(len(val_generator)/batch_size),\n"," epochs=epochs,\n"," callbacks=[csv_logger,\n"," model_ckpt,\n"," sample_img])\n","\n"," last_ckpt_name = ckpt_dir + '/' + model_name + '_last.hdf5'\n"," self.model.save_weights(last_ckpt_name)\n","\n"," def _min_max_scaling(self, data):\n"," n = data - np.min(data)\n"," d = np.max(data) - np.min(data) \n"," \n"," return n/d\n","\n"," def predict(self, \n"," input, \n"," ckpt_path, \n"," z_range=None, \n"," downscaling=None, \n"," true_patch_size=None):\n","\n"," self.model.load_weights(ckpt_path)\n","\n"," if isinstance(downscaling, str):\n"," downscaling = eval(downscaling)\n","\n"," if math.isnan(downscaling):\n"," downscaling = None\n","\n"," if isinstance(true_patch_size, str):\n"," true_patch_size = eval(true_patch_size)\n"," \n"," if not isinstance(true_patch_size, tuple): \n"," if math.isnan(true_patch_size):\n"," true_patch_size = None\n","\n"," if isinstance(input, str):\n"," src_volume = tifffile.imread(input)\n"," elif isinstance(input, np.ndarray):\n"," src_volume = input\n"," else:\n"," raise TypeError('Input is not path or numpy array!')\n"," \n"," in_size = src_volume.shape\n","\n"," if downscaling or true_patch_size is not None:\n"," x_scaling = 0\n"," y_scaling = 0\n","\n"," if true_patch_size is not None:\n"," x_scaling += true_patch_size[0]/self.shape[0]\n"," y_scaling += true_patch_size[1]/self.shape[1]\n"," if downscaling is not None:\n"," x_scaling += downscaling\n"," y_scaling += downscaling\n","\n"," src_list = []\n"," for i in range(src_volume.shape[0]):\n"," src_list.append(transform.downscale_local_mean(src_volume[i], (int(x_scaling), int(y_scaling))))\n"," src_volume = np.array(src_list) \n","\n"," if z_range is not None:\n"," src_volume = src_volume[z_range[0]:z_range[1]]\n","\n"," src_volume = self._min_max_scaling(src_volume) \n","\n"," src_array = np.zeros((1,\n"," math.ceil(src_volume.shape[1]/self.shape[0])*self.shape[0], \n"," math.ceil(src_volume.shape[2]/self.shape[1])*self.shape[1],\n"," math.ceil(src_volume.shape[0]/self.shape[2])*self.shape[2], \n"," self.shape[3]))\n","\n"," for i in range(src_volume.shape[0]):\n"," src_array[0,:src_volume.shape[1],:src_volume.shape[2],i,0] = src_volume[i]\n","\n"," pred_array = np.empty(src_array.shape)\n","\n"," for i in range(math.ceil(src_volume.shape[1]/self.shape[0])):\n"," for j in range(math.ceil(src_volume.shape[2]/self.shape[1])):\n"," for k in range(math.ceil(src_volume.shape[0]/self.shape[2])):\n"," pred_temp = self.model.predict(src_array[:,\n"," i*self.shape[0]:i*self.shape[0]+self.shape[0],\n"," j*self.shape[1]:j*self.shape[1]+self.shape[1],\n"," k*self.shape[2]:k*self.shape[2]+self.shape[2]])\n"," pred_array[:,\n"," i*self.shape[0]:i*self.shape[0]+self.shape[0],\n"," 
j*self.shape[1]:j*self.shape[1]+self.shape[1],\n"," k*self.shape[2]:k*self.shape[2]+self.shape[2]] = pred_temp\n"," \n"," pred_volume = np.rollaxis(np.squeeze(pred_array), -1)[:src_volume.shape[0],:src_volume.shape[1],:src_volume.shape[2]] \n","\n"," if downscaling is not None:\n"," pred_list = []\n"," for i in range(pred_volume.shape[0]):\n"," pred_list.append(transform.resize(pred_volume[i], (in_size[1], in_size[2]), preserve_range=True))\n"," pred_volume = np.array(pred_list)\n","\n"," return pred_volume\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," if os.path.isdir(training_source):\n"," shape = io.imread(training_source+'/'+os.listdir(training_source)[0]).shape\n"," elif os.path.isfile(training_source):\n"," shape = io.imread(training_source).shape\n"," else:\n"," print('Cannot read training data.')\n","\n"," dataset_size = len(train_generator)\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch_size: '+str(patch_size)+') with a batch size of '+str(batch_size)+' and a '+loss_function+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if add_gaussian_blur == True:\n"," aug_text = aug_text+'\\n- gaussian blur'\n"," if add_linear_contrast == True:\n"," aug_text = aug_text+'\\n- linear contrast'\n"," if add_additive_gaussian_noise == True:\n"," aug_text = aug_text+'\\n- additive gaussian noise'\n"," if augmenters != '':\n"," aug_text = aug_text+'\\n- imgaug augmentations: '+augmenters\n"," if add_elastic_deform == True:\n"," aug_text = aug_text+'\\n- elastic deformation'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if use_default_advanced_parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40% style=\"margin-left:0px;\">\n","  <tr>\n","    <th>Parameter</th>\n","    <th>Value</th>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>number_of_epochs</td>\n","    <td width = 50%>{0}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>batch_size</td>\n","    <td width = 50%>{1}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>patch_size</td>\n","    <td width = 50%>{2}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>image_pre_processing</td>\n","    <td width = 50%>{3}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>validation_split_in_percent</td>\n","    <td width = 50%>{4}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>downscaling_in_xy</td>\n","    <td width = 50%>{5}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>binary_target</td>\n","    <td width = 50%>{6}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>loss_function</td>\n","    <td width = 50%>{7}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>metrics</td>\n","    <td width = 50%>{8}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>optimizer</td>\n","    <td width = 50%>{9}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>checkpointing_period</td>\n","    <td width = 50%>{10}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>save_best_only</td>\n","    <td width = 50%>{11}</td>\n","  </tr>\n","
  <tr>\n","    <td width = 50%>resume_training</td>\n","    <td width = 50%>{12}</td>\n","  </tr>\n","  </table>
\n"," \"\"\".format(number_of_epochs,batch_size,str(patch_size[0])+'x'+str(patch_size[1])+'x'+str(patch_size[2]),image_pre_processing, validation_split_in_percent, downscaling_in_xy, str(binary_target), loss_function, metrics, optimizer, checkpointing_period, str(save_best_only), str(resume_training))\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(21, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Unet3D.png').shape\n"," pdf.image('/content/TrainingDataExample_Unet3D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet 3D: Çiçek, Özgün, et al. \"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," # if Use_Data_augmentation:\n"," # ref_4 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," # pdf.multi_cell(190, 5, txt = ref_4, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'U-Net 3D'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+qc_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Loss curves', ln=1, align='L')\n"," pdf.ln(1)\n"," if os.path.exists(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png'):\n"," exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png').shape\n"," pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png').shape\n"," pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'IoU threshold optimisation', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.ln(1)\n"," pdf.cell(120, 5, txt='Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh), align='L', ln=1)\n"," pdf.ln(2)\n"," exp_size = io.imread(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png').shape\n"," pdf.image(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png', x=16, y=None, w = round(exp_size[1]/6), h = round(exp_size[0]/6))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Unet 3D: Çiçek, Özgün, et al. 
\"3D U-Net: learning dense volumetric segmentation from sparse annotation.\" International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/'+qc_model_name+'_QC_report.pdf')\n","\n"," print('------------------------------')\n"," print('QC PDF report exported in '+os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/')\n","\n","\n","# -------------- Other definitions -----------\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","prediction_prefix = 'Predicted_'\n","\n","\n","print('-------------------')\n","print('U-Net 3D and dependencies installed.')\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m' # white (normal)\n"," \n","\n","# Check if this is the latest version of the notebook\n","# Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","# if Notebook_version == list(Latest_notebook_version.columns):\n","# print(\"This notebook is up-to-date.\")\n","\n","# if not Notebook_version == list(Latest_notebook_version.columns):\n","# print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Complete the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. 
Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount Google Drive**\n","---\n"," To use this notebook with your **own data**, place it in a folder on **Google Drive** following one of the directory structures outlined in **Section 0**.\n","\n","1. **Run** the **cell** below to mount your Google Drive and follow the link. \n","\n","2. **Sign in** to your Google account and press 'Allow'. \n","\n","3. Next, copy the **authorization code**, paste it into the cell and press enter. This will allow Colab to read and write data from and to your Google Drive. \n","\n","4. Once this is done, your data can be viewed in the **Files tab** on the top left of the notebook after hitting 'Refresh'."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":["## **3.1. Choosing parameters**\n","\n","---\n","\n","### **Paths to training data and model**\n","\n","* **`training_source`** and **`training_target`** specify the paths to the training data. They can either be a single multipage TIFF file each or directories containing various multipage TIFF files in which case target and source files must be named identically within the respective directories. See Section 0 for a detailed description of the necessary directory structure.\n","\n","* **`model_name`** will be used when naming checkpoints. Adhere to a `lower_case_with_underscores` naming convention and beware of using the name of an existing model within the same folder, as it will be overwritten.\n","\n","* **`model_path`** specifies the directory where the model checkpoints and quality control logs will be saved.\n","\n","\n","**Note:** You can copy paths from the 'Files' tab by right-clicking any folder or file and selecting 'Copy path'. \n","\n","### **Training parameters**\n","\n","* **`number_of_epochs`** is the number of times the entire training data will be seen by the model. *Default: >100*\n","\n","* **`batch_size`** is the number of training patches of size `patch_size` that will be bundled together at each training step. 
*Default: 1*\n","\n","* **`patch_size`** specifies the size of the three-dimensional training patches in (x, y, z) that will be fed to the model. In order to avoid errors, preferably use a square aspect ratio or stick to the advanced parameters. *Default: <(512, 512, 16)*\n","\n","* **`validation_split_in_percent`** is the relative amount of training data that will be set aside for validation. *Default: 20* \n","\n","* **`downscaling_in_xy`** downscales the training images by the specified amount in x and y. This is useful to enforce isotropic pixel-size if the z resolution is lower than the xy resolution in the training volume or to capture a larger field-of-view while decreasing the memory requirements. *Default: 1*\n","\n","* **`image_pre_processing`** selects whether the training images are randomly cropped during training or resized to `patch_size`. Choose `randomly crop to patch_size` to shrink the field-of-view of the training images to the `patch_size`. *Default: resize to patch_size* \n","\n","* **`binary_target`** forces the target image to be binary. Choose this if your model is trained to perform binary segmentation tasks *Default: True* \n","\n","* **`loss_function`** defines the loss. Read more [here](https://keras.io/api/losses/). *Default: weighted_binary_crossentropy* \n","\n","* **`metrics`** defines the metric. Read more [here](https://keras.io/api/metrics/). *Default: dice* \n","\n","* **`optimizer`** defines the optimizer. Read more [here](https://keras.io/api/optimizers/). *Default: adam* \n","\n","* **`learning_rate`** defines the learning rate. Read more [here](https://keras.io/api/optimizers/). *Default: 0.001* \n","\n","**Note:** If a *ResourceExhaustedError* is raised in Section 4.1. during training, decrease `batch_size` and `patch_size`. Decrease `batch_size` first and if the error persists at `batch_size = 1`, reduce the `patch_size`. \n","\n","**Note:** The number of steps per epoch are calculated as `floor(augment_factor * (1 - validation_split) * num_of_slices / batch_size)` if `image_pre_processing` is `resize to patch_size` where `augment_factor` is three if `apply_data_augmentation` is `True` and one otherwise. The `num_of_slices` is the overall number of slices (z-depth) in the training set across all provided image volumes. 
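For example (illustrative numbers only, not values from any particular dataset): with `augment_factor = 3`, a 20% validation split, `num_of_slices = 100` and `batch_size = 1`, this gives `floor(3 * 0.8 * 100 / 1) = 240` steps per epoch.\n","\n","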
If `image_pre_processing` is `randomly crop to patch_size`, the number of steps per epoch are calculated as `floor(augment_factor * volume / (crop_volume * batch_size))` where `volume` is the overall volume of the training data in pixels accounting for the validation split and `crop_volume` is defined as the volume in pixels based on the specified `patch_size`."]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training data:\n","training_source = \"\" #@param {type:\"string\"}\n","training_target = \"\" #@param {type:\"string\"}\n","\n","#@markdown ---\n","\n","#@markdown ###Model name and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","full_model_path = os.path.join(model_path, model_name)\n","\n","#@markdown ---\n","\n","#@markdown ###Training parameters\n","number_of_epochs = 100#@param {type:\"number\"}\n","\n","#@markdown ###Default advanced parameters\n","use_default_advanced_parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown If not, please change:\n","\n","batch_size = 1#@param {type:\"number\"}\n","patch_size = (256,256,8) #@param {type:\"number\"} # in pixels\n","training_shape = patch_size + (1,)\n","image_pre_processing = 'resize to patch_size' #@param [\"randomly crop to patch_size\", \"resize to patch_size\"]\n","\n","validation_split_in_percent = 20 #@param{type:\"number\"}\n","downscaling_in_xy = 1#@param {type:\"number\"} # in pixels\n","\n","binary_target = True #@param {type:\"boolean\"}\n","\n","loss_function = 'weighted_binary_crossentropy' #@param [\"weighted_binary_crossentropy\", \"binary_crossentropy\", \"categorical_crossentropy\", \"sparse_categorical_crossentropy\", \"mean_squared_error\", \"mean_absolute_error\"]\n","\n","metrics = 'dice' #@param [\"dice\", \"accuracy\"]\n","\n","optimizer = 'adam' #@param [\"adam\", \"sgd\", \"rmsprop\"]\n","\n","learning_rate = 0.00001 #@param{type:\"number\"}\n","\n","if image_pre_processing == \"randomly crop to patch_size\":\n"," random_crop = True\n","else:\n"," random_crop = False\n","\n","if use_default_advanced_parameters: \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," training_shape = (256,256,8,1)\n"," validation_split_in_percent = 20\n"," downscaling_in_xy = 1\n"," random_crop = True\n"," binary_target = True\n"," loss_function = 'weighted_binary_crossentropy'\n"," metrics = 'dice'\n"," optimizer = 'adam'\n"," learning_rate = 0.001 \n"," \n","#@markdown ###Checkpointing parameters\n","checkpointing_period = 1 #@param {type:\"number\"}\n","\n","#@markdown If chosen only the best checkpoint is saved, otherwise a checkpoint is saved every checkpoint_period epochs:\n","save_best_only = True #@param {type:\"boolean\"}\n","\n","#@markdown ###Resume training\n","#@markdown Choose if training was interrupted:\n","resume_training = False #@param {type:\"boolean\"}\n","\n","#@markdown ###Transfer learning\n","#@markdown For transfer learning, do not select resume_training and specify a checkpoint_path below:\n","checkpoint_path = \"\" #@param {type:\"string\"}\n","\n","if resume_training and checkpoint_path != \"\":\n"," print('If resume_training is True while checkpoint_path is specified, resume_training will be set to False!')\n"," resume_training = False\n"," \n","\n","# Retrieve last checkpoint\n","if resume_training:\n"," try:\n"," ckpt_dir_list = glob(full_model_path + '/ckpt/*')\n"," ckpt_dir_list.sort()\n"," last_ckpt_path = 
ckpt_dir_list[-1]\n"," print('Training will resume from checkpoint:', os.path.basename(last_ckpt_path))\n"," except IndexError:\n"," last_ckpt_path=None\n"," print('CheckpointError: No previous checkpoints were found, training from scratch.')\n","elif not resume_training and checkpoint_path != \"\":\n"," last_ckpt_path = checkpoint_path\n"," assert os.path.isfile(last_ckpt_path), 'checkpoint_path does not exist!'\n","else:\n"," last_ckpt_path=None\n","\n","# Instantiate Unet3D \n","model = Unet3D(shape=training_shape)\n","\n","#here we check that no model with the same name already exist\n","if not resume_training and os.path.exists(full_model_path): \n"," print(bcolors.WARNING+'The model folder already exists and will be overwritten.'+bcolors.NORMAL)\n"," # print('!! WARNING: Folder already exists and will be overwritten !!') \n"," # shutil.rmtree(full_model_path)\n","\n","# if not os.path.exists(full_model_path):\n","# os.makedirs(full_model_path)\n","\n","# Show sample image\n","if os.path.isdir(training_source):\n"," training_source_sample = sorted(glob(os.path.join(training_source, '*')))[0]\n"," training_target_sample = sorted(glob(os.path.join(training_target, '*')))[0]\n","else:\n"," training_source_sample = training_source\n"," training_target_sample = training_target\n","\n","src_sample = tifffile.imread(training_source_sample)\n","src_sample = model._min_max_scaling(src_sample)\n","if binary_target:\n"," tgt_sample = tifffile.imread(training_target_sample).astype(np.bool)\n","else:\n"," tgt_sample = tifffile.imread(training_target_sample)\n","\n","src_down = transform.downscale_local_mean(src_sample[0], (downscaling_in_xy, downscaling_in_xy))\n","tgt_down = transform.downscale_local_mean(tgt_sample[0], (downscaling_in_xy, downscaling_in_xy)) \n","\n","if random_crop:\n"," true_patch_size = None\n","\n"," if src_down.shape[0] == training_shape[0]:\n"," x_rand = 0\n"," if src_down.shape[1] == training_shape[1]:\n"," y_rand = 0\n"," if src_down.shape[0] > training_shape[0]:\n"," x_rand = np.random.randint(src_down.shape[0] - training_shape[0])\n"," if src_down.shape[1] > training_shape[1]:\n"," y_rand = np.random.randint(src_down.shape[1] - training_shape[1])\n"," if src_down.shape[0] < training_shape[0] or src_down.shape[1] < training_shape[1]:\n"," raise ValueError('Patch shape larger than (downscaled) source shape')\n","else:\n"," true_patch_size = src_down.shape\n","\n","def scroll_in_z(z):\n"," src_down = transform.downscale_local_mean(src_sample[z-1], (downscaling_in_xy,downscaling_in_xy))\n"," tgt_down = transform.downscale_local_mean(tgt_sample[z-1], (downscaling_in_xy,downscaling_in_xy)) \n"," if random_crop:\n"," src_slice = src_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n"," tgt_slice = tgt_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]\n"," else:\n"," \n"," src_slice = transform.resize(src_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n"," tgt_slice = transform.resize(tgt_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)\n","\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(src_slice, cmap='gray')\n"," plt.title('Training source (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(tgt_slice, cmap='magma')\n"," plt.title('Training target (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n"," 
plt.savefig('/content/TrainingDataExample_Unet3D.png',bbox_inches='tight',pad_inches=0)\n"," #plt.close()\n","\n","print('This is what the training images will look like with the chosen settings')\n","interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_sample.shape[0], step=1, value=0));\n","plt.show()\n","#Create a copy of an example slice and close the display.\n","scroll_in_z(z=int(src_sample.shape[0]/2))\n","# If you close the display, then the users can't interactively inspect the data\n","# plt.close()\n","\n","# Save model parameters\n","params = {'training_source': training_source,\n"," 'training_target': training_target,\n"," 'model_name': model_name,\n"," 'model_path': model_path,\n"," 'number_of_epochs': number_of_epochs,\n"," 'batch_size': batch_size,\n"," 'training_shape': training_shape,\n"," 'downscaling': downscaling_in_xy,\n"," 'true_patch_size': true_patch_size,\n"," 'val_split': validation_split_in_percent/100,\n"," 'random_crop': random_crop}\n","\n","params_df = pd.DataFrame.from_dict(params, orient='index')\n","\n","# apply_data_augmentation = False\n","# pdf_export(augmentation = apply_data_augmentation, pretrained_model = resume_training)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["## **3.2. Data augmentation**\n"," \n","---\n"," Augmenting the training data increases robustness of the model by simulating possible variations within the training data which avoids it from overfitting on small datasets. We therefore strongly recommended augmenting the data and making sure that the applied augmentations are reasonable.\n","\n","* **Gaussian blur** blurs images using Gaussian kernels with a sigma of `gaussian_sigma`. This augmentation step is applied with a probability of `gaussian_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/blur.html#gaussianblur).\n","\n","* **Linear contrast** modifies the contrast of images according to `127 + alpha *(pixel_value-127)`, where `pixel_value` and `alpha` are sampled uniformly from the interval `[contrast_min, contrast_max]`. This augmentation step is applied with a probability of `contrast_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/contrast.html#linearcontrast).\n","\n","* **Additive Gaussian noise** adds Gaussian noise sampled once per pixel from a normal distribution `N(0, s)`, where `s` is sampled from `[scale_min, scale_max]`. This augmentation step is applied with a probability of `noise_frequency`. Read more [here](https://imgaug.readthedocs.io/en/latest/source/overview/arithmetic.html#additivegaussiannoise).\n","\n","* **Add custom augmenters** allows you to create a custom augmentation pipeline using the [augmenters available in the imagug library](https://imgaug.readthedocs.io/en/latest/source/overview_of_augmenters.html).\n","In the example above, the augmentation pipeline is equivalent to: \n","```\n","seq = iaa.Sequential([\n"," iaa.Sometimes(0.3, iaa.GammaContrast((0.5, 2.0)), \n"," iaa.Sometimes(0.4, iaa.AverageBlur((0.5, 2.0)), \n"," iaa.Sometimes(0.5, iaa.LinearContrast((0.4, 1.6)), \n","], random_order=True)\n","```\n"," Note that there is no limit on the number of augmenters that can be chained together and that individual augmenter and parameter entries must be separated by `;`. Custom augmenters do not overwrite the preset augmentation steps (*Gaussian blur*, *Linear contrast* or *Additive Gaussian noise*). 
Also, the augmenters, augmenter parameters and augmenter frequencies must be entered such that each position within the string corresponds to the same augmentation step.\n","\n","* **`apply_data_augmentation`** ensures that data augmentation is randomly applied to the training data at each training step. This includes inverting the order of the slices within a training patch, as well as applying any augmenters that are added. *Default: True*\n","\n","* **`add_elastic_deform`** ensures that elastic grid-based deformations are applied as described in the original 3D U-Net paper. *Default: True*"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#@markdown ##**Augmentation options**\n","\n","#@markdown ###Data augmentation\n","\n","apply_data_augmentation = True #@param {type:\"boolean\"}\n","\n","# List of augmentations\n","augmentations = []\n","\n","#@markdown ###Gaussian blur\n","add_gaussian_blur = True #@param {type:\"boolean\"}\n","gaussian_sigma = 0.7#@param {type:\"number\"}\n","gaussian_frequency = 0.5 #@param {type:\"number\"}\n","\n","if add_gaussian_blur:\n"," augmentations.append(iaa.Sometimes(gaussian_frequency, iaa.GaussianBlur(sigma=(0, gaussian_sigma))))\n","\n","#@markdown ###Linear contrast\n","add_linear_contrast = True #@param {type:\"boolean\"}\n","contrast_min = 0.4 #@param {type:\"number\"}\n","contrast_max = 1.6#@param {type:\"number\"}\n","contrast_frequency = 0.5 #@param {type:\"number\"}\n","\n","if add_linear_contrast:\n"," augmentations.append(iaa.Sometimes(contrast_frequency, iaa.LinearContrast((contrast_min, contrast_max))))\n","\n","#@markdown ###Additive Gaussian noise\n","add_additive_gaussian_noise = False #@param {type:\"boolean\"}\n","scale_min = 0 #@param {type:\"number\"}\n","scale_max = 0.05 #@param {type:\"number\"}\n","noise_frequency = 0.5 #@param {type:\"number\"}\n","\n","if add_additive_gaussian_noise:\n"," augmentations.append(iaa.Sometimes(noise_frequency, iaa.AdditiveGaussianNoise(scale=(scale_min, scale_max))))\n","\n","#@markdown ###Add custom augmenters\n","add_custom_augmenters = False #@param {type:\"boolean\"} \n","augmenters = \"\" #@param {type:\"string\"}\n","\n","if add_custom_augmenters:\n","\n"," augmenter_params = \"\" #@param {type:\"string\"}\n","\n"," augmenter_frequency = \"\" #@param {type:\"string\"}\n","\n"," aug_lst = augmenters.split(';')\n"," aug_params_lst = augmenter_params.split(';')\n"," aug_freq_lst = augmenter_frequency.split(';')\n","\n"," assert len(aug_lst) == len(aug_params_lst) and len(aug_lst) == len(aug_freq_lst), 'The number of arguments in augmenters, augmenter_params and augmenter_frequency are not the same!'\n","\n"," for __, (aug, param, freq) in enumerate(zip(aug_lst, aug_params_lst, aug_freq_lst)):\n"," aug, param, freq = aug.strip(), param.strip(), freq.strip() \n"," aug_func = iaa.Sometimes(eval(freq), getattr(iaa, aug)(eval(param)))\n"," augmentations.append(aug_func)\n","\n","#@markdown ###Elastic deformations\n","add_elastic_deform = True #@param {type:\"boolean\"}\n","sigma = 2#@param {type:\"number\"}\n","points = 2#@param {type:\"number\"}\n","order = 2#@param {type:\"number\"}\n","\n","if add_elastic_deform:\n"," deform_params = (sigma, points, order)\n","else:\n"," deform_params = None\n","\n","train_generator = MultiPageTiffGenerator(training_source,\n"," training_target,\n"," batch_size=batch_size,\n"," shape=training_shape,\n"," augment=apply_data_augmentation,\n"," augmentations=augmentations,\n"," deform_augment=add_elastic_deform,\n"," 
deform_augmentation_params=deform_params,\n"," val_split=validation_split_in_percent/100,\n"," random_crop=random_crop,\n"," downscale=downscaling_in_xy,\n"," binary_target=binary_target)\n","\n","val_generator = MultiPageTiffGenerator(training_source,\n"," training_target,\n"," batch_size=batch_size,\n"," shape=training_shape,\n"," val_split=validation_split_in_percent/100,\n"," is_val=True,\n"," random_crop=random_crop,\n"," downscale=downscaling_in_xy,\n"," binary_target=binary_target)\n","\n","\n","if apply_data_augmentation:\n"," print('Data augmentation enabled.')\n"," sample_src_aug, sample_tgt_aug = train_generator.sample_augmentation(random.randint(0, len(train_generator)))\n","\n"," def scroll_in_z(z):\n"," f=plt.figure(figsize=(16,8))\n"," plt.subplot(1,2,1)\n"," plt.imshow(sample_src_aug[0,:,:,z-1,0], cmap='gray')\n"," plt.title('Sample augmented source (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(sample_tgt_aug[0,:,:,z-1,0], cmap='magma')\n"," plt.title('Sample training target (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," print('This is what the augmented training images will look like with the chosen settings')\n"," interact(scroll_in_z, z=widgets.IntSlider(min=1, max=sample_src_aug.shape[3], step=1, value=0));\n","\n","else:\n"," print('Data augmentation disabled.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["# **4. Train the network**\n","---\n","\n","**CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training times must be less than 12 hours! If training takes longer than 12 hours, please decrease `number_of_epochs`."]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Show model and start training**\n","---\n"]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ## Show model summary\n","model.summary()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"CyQI4ssarUp4","cellView":"form"},"source":["#@markdown ##Start training\n","\n","#here we check that no model with the same name already exist, if so delete\n","if not resume_training and os.path.exists(full_model_path): \n"," shutil.rmtree(full_model_path)\n"," print(bcolors.WARNING+'!! 
WARNING: Folder already exists and has been overwritten !!'+bcolors.NORMAL) \n","\n","if not os.path.exists(full_model_path):\n"," os.makedirs(full_model_path)\n","\n","pdf_export(augmentation = apply_data_augmentation, pretrained_model = resume_training)\n","\n","# Save file\n","params_df.to_csv(os.path.join(full_model_path, 'params.csv'))\n","\n","start = time.time()\n","# Start Training\n","model.train(epochs=number_of_epochs,\n"," batch_size=batch_size,\n"," train_generator=train_generator,\n"," val_generator=val_generator,\n"," model_path=model_path,\n"," model_name=model_name,\n"," loss=loss_function,\n"," metrics=metrics,\n"," optimizer=optimizer,\n"," learning_rate=learning_rate,\n"," ckpt_period=checkpointing_period,\n"," save_best_ckpt_only=save_best_only,\n"," ckpt_path=last_ckpt_path)\n","\n","print('Training successfully completed!')\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = apply_data_augmentation, pretrained_model = resume_training)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["##**4.2. Download your model from Google Drive**\n","\n","---\n","Once training is complete, the trained model is automatically saved to your Google Drive, in the **`model_path`** folder that was specified in Section 3. Download the folder to avoid any unwanted surprises, since the data can be erased if you train another model using the same `model_path`."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Download model directory\n","#@markdown 1. Specify the model_path in `model_path_download` otherwise the model sepcified in Section 3.1 will be downloaded\n","#@markdown 2. Run this cell to zip the model directory\n","#@markdown 3. Download the zipped file from the *Files* tab on the left\n","\n","from google.colab import files\n","\n","model_path_download = \"\" #@param {type:\"string\"}\n","\n","if len(model_path_download) == 0:\n"," model_path_download = full_model_path\n","\n","model_name_download = os.path.basename(model_path_download)\n","\n","print('Zipping', model_name_download)\n","\n","zip_model_path = model_name_download + '.zip'\n","\n","!zip -r \"$zip_model_path\" \"$model_path_download\"\n","\n","print('Successfully saved zipped model directory as', zip_model_path)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","In this section the newly trained model can be assessed for performance. This involves inspecting the loss function in Section 5.1. 
and employing more advanced metrics in Section 5.2.\n","\n","**We highly recommend performing quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["#@markdown ###Model to be evaluated:\n","#@markdown If left blank, the latest model defined in Section 3 will be evaluated:\n","\n","qc_model_name = \"\" #@param {type:\"string\"}\n","qc_model_path = \"\" #@param {type:\"string\"}\n","\n","if len(qc_model_path) == 0 and len(qc_model_name) == 0:\n"," qc_model_name = model_name\n"," qc_model_path = model_path\n","\n","full_qc_model_path = os.path.join(qc_model_path, qc_model_name)\n","\n","if os.path.exists(full_qc_model_path):\n"," print(qc_model_name + ' will be evaluated')\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspecting loss function**\n","---\n","\n","**The training loss** is the error between prediction and target after each epoch calculated across the training data while the **validation loss** calculates the error on the (unseen) validation data. During training these values should decrease until converging at which point the model has been sufficiently trained. If the validation loss starts increasing while the training loss has plateaued, the model has overfit on the training data which reduces its ability to generalise. Aim to halt training before this point.\n","\n","**Note:** For a more in-depth explanation please refer to [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols et al.\n","\n","\n","The accuracy is another performance metric that is calculated after each epoch. We use the [Sørensen–Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) to score the prediction accuracy. 
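\n","\n","As an illustration of what this metric measures, here is a minimal sketch of the Sørensen–Dice coefficient computed on two binary masks with NumPy (the function name and inputs are hypothetical and not the exact implementation used internally by this notebook):\n","\n","```\n","import numpy as np\n","
def dice_coefficient(prediction, target):\n","    # prediction and target are boolean masks of identical shape (illustrative only)\n","    intersection = np.logical_and(prediction, target).sum()\n","    # Dice = 2 * |A intersect B| / (|A| + |B|); assumes at least one mask is non-empty\n","    return 2.0 * intersection / (prediction.sum() + target.sum())\n","```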
\n","\n"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Visualise loss and accuracy\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","accuracyDataFromCSV = []\n","valaccuracyDataFromCSV = []\n","\n","with open(full_qc_model_path + '/Quality Control/training_evaluation.csv', 'r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[2]))\n"," vallossDataFromCSV.append(float(row[4]))\n"," accuracyDataFromCSV.append(float(row[1]))\n"," valaccuracyDataFromCSV.append(float(row[3]))\n","\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training and validation loss', fontsize=14)\n","plt.ylabel('Loss', fontsize=12)\n","plt.xlabel('Epochs', fontsize=12)\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.plot(epochNumber,accuracyDataFromCSV, label='Training accuracy')\n","plt.plot(epochNumber,valaccuracyDataFromCSV, label='Validation accuracy')\n","plt.title('Training and validation accuracy', fontsize=14)\n","plt.ylabel('Dice', fontsize=12)\n","plt.xlabel('Epochs', fontsize=12)\n","plt.legend()\n","plt.savefig(full_qc_model_path + '/Quality Control/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","This section will provide both a visual indication of the model performance by comparing the overlay of the predicted and source volume."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Compare prediction and ground-truth on testing data\n","\n","#@markdown Provide an unseen annotated dataset to determine the performance of the model:\n","\n","testing_source = \"\" #@param{type:\"string\"}\n","testing_target = \"\" #@param{type:\"string\"}\n","\n","qc_dir = full_qc_model_path + '/Quality Control'\n","predict_dir = qc_dir + '/Prediction'\n","if os.path.exists(predict_dir):\n"," shutil.rmtree(predict_dir)\n","\n","os.makedirs(predict_dir)\n","\n","# predict_dir + '/' + \n","predict_path = os.path.splitext(os.path.basename(testing_source))[0] + '_prediction.tif'\n","\n","def last_chars(x):\n"," return(x[-11:])\n","\n","try:\n"," ckpt_dir_list = glob(full_qc_model_path + '/ckpt/*')\n"," ckpt_dir_list.sort(key=last_chars)\n"," last_ckpt_path = ckpt_dir_list[0]\n"," print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n","except IndexError:\n"," raise CheckpointError('No previous checkpoints were found, please retrain model.')\n","\n","# Load parameters\n","params = pd.read_csv(os.path.join(full_qc_model_path, 'params.csv'), names=['val'], header=0, index_col=0) \n","\n","model = Unet3D(shape=params.loc['training_shape', 'val'])\n","\n","prediction = model.predict(testing_source, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n","\n","tifffile.imwrite(predict_path, prediction.astype('float32'), imagej=True)\n","\n","print('Predicted images!')\n","\n","qc_metrics_path = full_qc_model_path + '/Quality Control/QC_metrics_' + qc_model_name + '.csv'\n","\n","test_target = tifffile.imread(testing_target)\n","test_source = 
tifffile.imread(testing_source)\n","test_prediction = tifffile.imread(predict_path)\n","\n","def scroll_in_z(z):\n","\n"," plt.figure(figsize=(25,5))\n"," # Source\n"," plt.subplot(1,4,1)\n"," plt.axis('off')\n"," plt.imshow(test_source[z-1], cmap='gray')\n"," plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n","\n"," # Target (Ground-truth)\n"," plt.subplot(1,4,2)\n"," plt.axis('off')\n"," plt.imshow(test_target[z-1], cmap='magma')\n"," plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n","\n"," # Prediction\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(test_prediction[z-1], cmap='magma')\n"," plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n"," \n"," # Overlay\n"," plt.subplot(1,4,4)\n"," plt.axis('off')\n"," plt.imshow(test_target[z-1], cmap='Greens')\n"," plt.imshow(test_prediction[z-1], alpha=0.5, cmap='Purples')\n"," plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n"," plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_example_data.png', bbox_inches='tight', pad_inches=0)\n","interact(scroll_in_z, z=widgets.IntSlider(min=1, max=test_source.shape[0], step=1, value=0));"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aIvRxpZlsFeZ"},"source":["## **5.3. Determine best Intersection over Union and threshold**\n","---\n","\n","**Note:** This section is only relevant if the target image is a binary mask and `binary_target` is selected in Section 3! \n","\n","This section will provide both a visual and a quantitative indication of the model performance by comparing the overlay of the predicted and source volume, as well as computing the highest [**Intersection over Union**](https://en.wikipedia.org/wiki/Jaccard_index) (IoU) score. The IoU is also known as the Jaccard Index. \n","\n","The best threshold is calculated using the IoU. Each threshold value from 0 to 255 is tested and the threshold with the highest score is deemed the best. 
The IoU is calculated for the entire volume in 3D."]},{"cell_type":"code","metadata":{"id":"XhkeZTFusHA8","cellView":"form"},"source":["\n","#@markdown ##Calculate Intersection over Union and best threshold \n","prediction = tifffile.imread(predict_path)\n","prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n","\n","target = tifffile.imread(testing_target).astype(np.bool)\n","\n","def iou_vs_threshold(prediction, target):\n"," threshold_list = []\n"," IoU_scores_list = []\n","\n"," for threshold in range(0,256): \n"," mask = prediction > threshold\n","\n"," intersection = np.logical_and(target, mask)\n"," union = np.logical_or(target, mask)\n"," iou_score = np.sum(intersection) / np.sum(union)\n","\n"," threshold_list.append(threshold)\n"," IoU_scores_list.append(iou_score)\n","\n"," return threshold_list, IoU_scores_list\n","\n","threshold_list, IoU_scores_list = iou_vs_threshold(prediction, target)\n","thresh_arr = np.array(list(zip(threshold_list, IoU_scores_list)))\n","best_thresh = int(np.where(thresh_arr == np.max(thresh_arr[:,1]))[0])\n","best_iou = IoU_scores_list[best_thresh]\n","\n","print('Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh))\n","\n","def adjust_threshold(threshold, z):\n","\n"," f=plt.figure(figsize=(25,5))\n"," plt.subplot(1,4,1)\n"," plt.imshow((prediction[z-1] > threshold).astype('uint8'), cmap='magma')\n"," plt.title('Prediction (Threshold = ' + str(threshold) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,4,2)\n"," plt.imshow(target[z-1], cmap='magma')\n"," plt.title('Target (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,4,3)\n"," plt.axis('off')\n"," plt.imshow(test_source[z-1], cmap='gray')\n"," plt.imshow((prediction[z-1] > threshold).astype('uint8'), alpha=0.4, cmap='Reds')\n"," plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)\n","\n"," plt.subplot(1,4,4)\n"," plt.title('Threshold vs. IoU', fontsize=15)\n"," plt.plot(threshold_list, IoU_scores_list)\n"," plt.plot(threshold, IoU_scores_list[threshold], 'ro') \n"," plt.ylabel('IoU score')\n"," plt.xlabel('Threshold')\n"," plt.savefig(os.path.join(qc_model_path,qc_model_name,'Quality Control')+'/QC_IoU_analysis.png',bbox_inches=matplotlib.transforms.Bbox([[17.5,0],[23,5]]),pad_inches=0)\n"," plt.show()\n","\n","interact(adjust_threshold, \n"," threshold=widgets.IntSlider(min=0, max=255, step=1, value=best_thresh),\n"," z=widgets.IntSlider(min=1, max=prediction.shape[0], step=1, value=0));\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","Once sufficient performance of the trained model has been established using Section 5, the network can be used to segment unseen volumetric data."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate predictions from unseen dataset**\n","---\n","\n","The most recently trained model can now be used to predict segmentation masks on unseen images. If you want to use an older model, leave `model_path` blank. 
Predicted output images are saved in `output_path` as Image-J compatible TIFF files.\n","\n","## **Prediction parameters**\n","\n","* **`source_path`** specifies the location of the source \n","image volume.\n","\n","* **`output_directory`** specified the directory where the output predictions are stored.\n","\n","* **`binary_target`** should be chosen if the network is trained to predict binary segmentation masks.\n","\n","* **`threshold`** can be calculated in Section 5 and is used to generate binary masks from the predictions.\n","\n","* **`big_tiff`** should be chosen if the expected prediction exceeds 4GB. The predictions will be saved using the BigTIFF format. Beware that this might substantially reduce the prediction speed. *Default: False* \n","\n","* **`prediction_depth`** is only relevant if the prediction is saved as a BigTIFF. The prediction will not be performed in one go to not deplete the memory resources. Instead, the prediction is iteratively performed on a subset of the entire volume with shape `(source.shape[0], source.shape[1], prediction_depth)`. *Default: 32*\n","\n","* **`model_path`** specifies the path to a model other than the most recently trained."]},{"cell_type":"code","metadata":{"cellView":"form","id":"DEmhPh5fsWX2"},"source":["#@markdown ## Download example volume\n","\n","#@markdown This can take up to an hour\n","\n","import requests \n","import os\n","from tqdm.notebook import tqdm \n","\n","\n","def download_from_url(url, save_as):\n"," file_url = url\n"," r = requests.get(file_url, stream=True) \n"," \n"," with open(save_as, 'wb') as file: \n"," for block in tqdm(r.iter_content(chunk_size = 1024), desc = 'Downloading ' + os.path.basename(save_as), total=3275073, ncols=1000):\n"," if block:\n"," file.write(block) \n","\n","download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/volumedata.tif', 'example_dataset/volumedata.tif')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then run the cell to predict outputs from your unseen images.\n","\n","source_path = \"\" #@param {type:\"string\"}\n","output_directory = \"\" #@param {type:\"string\"}\n","\n","if not os.path.exists(output_directory):\n"," os.makedirs(output_directory)\n","\n","output_path = os.path.join(output_directory, os.path.splitext(os.path.basename(source_path))[0] + '_predicted.tif')\n","#@markdown ###Prediction parameters:\n","\n","binary_target = True #@param {type:\"boolean\"}\n","\n","save_probability_map = False #@param {type:\"boolean\"}\n","\n","#@markdown Determine best threshold in Section 5.2.\n","\n","use_calculated_threshold = True #@param {type:\"boolean\"}\n","threshold = 200#@param {type:\"number\"}\n","\n","# Tifffile library issues means that images cannot be appended to \n","#@markdown Choose if prediction file exceeds 4GB or if input file is very large (above 2GB). Image volume saved as BigTIFF.\n","big_tiff = False #@param {type:\"boolean\"}\n","\n","#@markdown Reduce `prediction_depth` if runtime runs out of memory during prediction. 
Only relevant if prediction saved as BigTIFF\n","\n","prediction_depth = 32#@param {type:\"number\"}\n","\n","#@markdown ###Model to be evaluated\n","#@markdown If left blank, the latest model defined in Section 5 will be evaluated\n","\n","full_model_path_ = \"\" #@param {type:\"string\"}\n","\n","if len(full_model_path_) == 0:\n"," full_model_path_ = os.path.join(qc_model_path, qc_model_name) \n","\n","\n","\n","# Load parameters\n","params = pd.read_csv(os.path.join(full_model_path_, 'params.csv'), names=['val'], header=0, index_col=0) \n","model = Unet3D(shape=params.loc['training_shape', 'val'])\n","\n","if use_calculated_threshold:\n"," threshold = best_thresh\n","\n","def last_chars(x):\n"," return(x[-11:])\n","\n","try:\n"," ckpt_dir_list = glob(full_model_path_ + '/ckpt/*')\n"," ckpt_dir_list.sort(key=last_chars)\n"," last_ckpt_path = ckpt_dir_list[0]\n"," print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))\n","except IndexError:\n"," raise CheckpointError('No previous checkpoints were found, please retrain model.')\n","\n","src = tifffile.imread(source_path)\n","\n","if src.nbytes >= 4e9:\n"," big_tiff = True\n"," print('The source file exceeds 4GB in memory, prediction will be saved as BigTIFF!')\n","\n","if binary_target:\n"," if not big_tiff:\n"," prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," prediction = (prediction > threshold).astype('float32')\n","\n"," tifffile.imwrite(output_path, prediction, imagej=True)\n","\n"," else:\n"," with tifffile.TiffWriter(output_path, bigtiff=True) as tif:\n"," for i in tqdm(range(0, src.shape[0], prediction_depth)):\n"," prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," prediction = (prediction > threshold).astype('float32')\n"," \n"," for j in range(prediction.shape[0]):\n"," tif.save(prediction[j])\n","\n","if not binary_target or save_probability_map:\n"," if not binary_target:\n"," prob_map_path = output_path\n"," else:\n"," prob_map_path = os.path.splitext(output_path)[0] + '_prob_map.tif'\n"," \n"," if not big_tiff:\n"," prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," tifffile.imwrite(prob_map_path, prediction.astype('float32'), imagej=True)\n","\n"," else:\n"," with tifffile.TiffWriter(prob_map_path, bigtiff=True) as tif:\n"," for i in tqdm(range(0, src.shape[0], prediction_depth)):\n"," prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])\n"," prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))\n"," \n"," for j in range(prediction.shape[0]):\n"," tif.save(prediction[j])\n","\n","print('Predictions saved as', output_path)\n","\n","src_volume = tifffile.imread(source_path)\n","pred_volume = tifffile.imread(output_path)\n","\n","def scroll_in_z(z):\n"," \n"," f=plt.figure(figsize=(25,5))\n"," plt.subplot(1,2,1)\n"," plt.imshow(src_volume[z-1], cmap='gray')\n"," 
plt.title('Source (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n"," plt.subplot(1,2,2)\n"," plt.imshow(pred_volume[z-1], cmap='magma')\n"," plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)\n"," plt.axis('off')\n","\n","interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_volume.shape[0], step=1, value=0));\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.2. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"q3lSeWp3G8eD"},"source":["# **7. Version log**\n","\n","---\n","**v1.13**: \n","* The section 1 and 2 are now swapped for better export of *requirements.txt*. \n","* This version also now includes built-in version check and the version log that you're reading now.\n","* Keras libraries are now imported via TensorFlow.\n","* The learning rate can be changed in section 3.1.\n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using 3D U-Net!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb b/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb index bb1ac059..52f1f952 100644 --- a/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/YOLOv2_ZeroCostDL4Mic.ipynb @@ -1 +1 @@ -{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"YOLOv2_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1VlYfohmBOvSVtkYci7R2gktMj-F32oK4","timestamp":1622645551473},{"file_id":"1bQuSKv6gvjvWhnzoIVjqUNvlC3F_-Jvw","timestamp":1619709372524},{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1610968154980},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **YOLOv2**\n","---\n","\n"," YOLOv2 is a deep-learning method designed to perform object detection and classification of objects in images, published by [Redmon and Farhadi](https://ieeexplore.ieee.org/document/8100173). This is based on the original [YOLO](https://arxiv.org/abs/1506.02640) implementation published by the same authors. YOLOv2 is trained on images with class annotations in the form of bounding boxes drawn around the objects of interest. The images are downsampled by a convolutional neural network (CNN) and objects are classified in two final fully connected layers in the network. 
YOLOv2 learns classification and object detection simultaneously by taking the whole input image into account, predicting many possible bounding box solutions, and then using regression to find the best bounding boxes and classifications for each object.\n","\n","**This particular notebook enables object detection and classification on 2D images given ground truth bounding boxes. If you are interested in image segmentation, you should use our U-net or Stardist notebooks instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following papers: \n","\n","**YOLO9000: Better, Faster, Stronger** from Joseph Redmon and Ali Farhadi in Proceedings of the IEEE conference on computer vision and pattern recognition, 2017, (https://ieeexplore.ieee.org/document/8100173)\n","\n","**You Only Look Once: Unified, Real-Time Object Detection** from Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi in IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, (https://ieeexplore.ieee.org/document/7780460)\n","\n","**Note: The source code for this notebook is adapted for keras and can be found in: (/~https://github.com/experiencor/keras-yolo2)**\n","\n","\n","**Please also cite these original papers when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. 
You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n"," Preparing the dataset carefully is essential to make this YOLOv2 notebook work. This model requires as input a set of images (currently .jpg) and as target a list of annotation files in Pascal VOC format. The annotation files should have the exact same name as the input files, except with an .xml instead of the .jpg extension. The annotation files contain the class labels and all bounding boxes for the objects for each image in your dataset. Most datasets will give the option of saving the annotations in this format or using software for hand-annotations will automatically save the annotations in this format. \n","\n"," If you want to assemble your own dataset we recommend using the open source https://www.makesense.ai/ resource. You can follow our instructions on how to label your dataset with this tool on our [wiki](/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki/Object-Detection-(YOLOv2)).\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .png or .jpg files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Input images (Training_source)\n"," - img_1.png, img_2.png, ...\n"," - High SNR images (Training_source_annotations)\n"," - img_1.xml, img_2.xml, ...\n"," - **Quality control dataset**\n"," - Input images\n"," - img_1.png, img_2.png\n"," - High SNR images\n"," - img_1.xml, img_2.xml\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. 
Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **2. Install YOLOv2 and Dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"bEN_Qt10Opz-"},"source":["## **2.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"cellView":"form","id":"5fGzcX6sOTxH"},"source":["#@markdown ##Install YOLOv2 and dependencies\n","\n","!pip install pascal-voc-writer\n","!pip install fpdf\n","!pip install PTable\n","!pip install h5py==2.10\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"zljHDmoLOu4W"},"source":["## **2.2. Restart your runtime**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"SNLwKiVXO000"},"source":["** Your Runtime has automatically restarted. This is normal.**\n"]},{"cell_type":"markdown","metadata":{"id":"8OsMrZ8hO7D8"},"source":["## **2.3. 
Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12.2']\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install Network and Dependencies\n","%tensorflow_version 1.x\n","\n","from pascal_voc_writer import Writer\n","from __future__ import division\n","from __future__ import print_function\n","from __future__ import absolute_import\n","import csv\n","import random\n","import pprint\n","import time\n","import numpy as np\n","from optparse import OptionParser\n","import pickle\n","import math\n","import cv2\n","import copy\n","import math\n","from matplotlib import pyplot as plt\n","import matplotlib.patches as patches\n","import tensorflow as tf\n","import pandas as pd\n","import os\n","import shutil\n","from skimage import io\n","from sklearn.metrics import average_precision_score\n","\n","from keras.models import Model\n","from keras.layers import Flatten, Dense, Input, Conv2D, MaxPooling2D, Dropout, Reshape, Activation, Conv2D, MaxPooling2D, BatchNormalization, Lambda\n","from keras.layers.advanced_activations import LeakyReLU\n","from keras.layers.merge import concatenate\n","from keras.applications.mobilenet import MobileNet\n","from keras.applications import InceptionV3\n","from keras.applications.vgg16 import VGG16\n","from keras.applications.resnet50 import ResNet50\n","\n","from keras import backend as K\n","from keras.optimizers import Adam, SGD, RMSprop\n","from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, TimeDistributed\n","from keras.engine.topology import get_source_inputs\n","from keras.utils import layer_utils\n","from keras.utils.data_utils import get_file\n","from keras.objectives import categorical_crossentropy\n","from keras.models import Model\n","from keras.utils import generic_utils\n","from keras.engine import Layer, InputSpec\n","from keras import initializers, regularizers\n","from keras.utils import Sequence\n","import xml.etree.ElementTree as ET\n","from collections import OrderedDict, Counter\n","import json\n","import imageio\n","import 
imgaug as ia\n","from imgaug import augmenters as iaa\n","import copy\n","import cv2\n","from tqdm import tqdm\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess as sp\n","\n","from prettytable import from_csv\n","\n","# from matplotlib.pyplot import imread\n","\n","ia.seed(1)\n","# imgaug uses matplotlib backend for displaying images\n","from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage\n","import re\n","import glob\n","\n","#Here, we import a different github repo which includes the map_evaluation.py\n","!git clone /~https://github.com/rodrigo2019/keras_yolo2.git\n","\n","if os.path.exists('/content/gdrive/My Drive/keras-yolo2'):\n"," shutil.rmtree('/content/gdrive/My Drive/keras-yolo2')\n","\n","#Here, we import the main github repo for this notebook and move it to the gdrive\n","!git clone /~https://github.com/experiencor/keras-yolo2.git\n","shutil.move('/content/keras-yolo2','/content/gdrive/My Drive/keras-yolo2')\n","#Now, we move the map_evaluation.py file to the main repo for this notebook.\n","#The source repo of the map_evaluation.py can then be ignored and is not further relevant for this notebook.\n","shutil.move('/content/keras_yolo2/keras_yolov2/map_evaluation.py','/content/gdrive/My Drive/keras-yolo2/map_evaluation.py')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","\n","from backend import BaseFeatureExtractor, FullYoloFeature\n","from preprocessing import parse_annotation, BatchGenerator\n","\n","\n","\n","def plt_rectangle(plt,label,x1,y1,x2,y2,fontsize=10):\n"," '''\n"," == Input ==\n"," \n"," plt : matplotlib.pyplot object\n"," label : string containing the object class name\n"," x1 : top left corner x coordinate\n"," y1 : top left corner y coordinate\n"," x2 : bottom right corner x coordinate\n"," y2 : bottom right corner y coordinate\n"," '''\n"," linewidth = 1\n"," color = \"yellow\"\n"," plt.text(x1,y1,label,fontsize=fontsize,backgroundcolor=\"magenta\")\n"," plt.plot([x1,x1],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x2,x2],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y1,y1], linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y2,y2], linewidth=linewidth,color=color)\n","\n","def extract_single_xml_file(tree,object_count=True):\n"," Nobj = 0\n"," row = OrderedDict()\n"," for elems in tree.iter():\n","\n"," if elems.tag == \"size\":\n"," for elem in elems:\n"," row[elem.tag] = int(elem.text)\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"name\":\n"," row[\"bbx_{}_{}\".format(Nobj,elem.tag)] = str(elem.text) \n"," if elem.tag == \"bndbox\":\n"," for k in elem:\n"," row[\"bbx_{}_{}\".format(Nobj,k.tag)] = float(k.text)\n"," Nobj += 1\n"," if object_count == True:\n"," row[\"Nobj\"] = Nobj\n"," return(row)\n","\n","def count_objects(tree):\n"," Nobj=0\n"," for elems in tree.iter():\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"bndbox\":\n"," Nobj += 1\n"," return(Nobj)\n","\n","def compute_overlap(a, b):\n"," \"\"\"\n"," Code originally from /~https://github.com/rbgirshick/py-faster-rcnn.\n"," Parameters\n"," ----------\n"," a: (N, 4) ndarray of float\n"," b: (K, 4) ndarray of float\n"," Returns\n"," -------\n"," overlaps: (N, K) ndarray of overlap between boxes and query_boxes\n"," \"\"\"\n"," area = (b[:, 2] - b[:, 0]) * (b[:, 
3] - b[:, 1])\n","\n"," iw = np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0])\n"," ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1])\n","\n"," iw = np.maximum(iw, 0)\n"," ih = np.maximum(ih, 0)\n","\n"," ua = np.expand_dims((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), axis=1) + area - iw * ih\n","\n"," ua = np.maximum(ua, np.finfo(float).eps)\n","\n"," intersection = iw * ih\n","\n"," return intersection / ua\n","\n","def compute_ap(recall, precision):\n"," \"\"\" Compute the average precision, given the recall and precision curves.\n"," Code originally from /~https://github.com/rbgirshick/py-faster-rcnn.\n","\n"," # Arguments\n"," recall: The recall curve (list).\n"," precision: The precision curve (list).\n"," # Returns\n"," The average precision as computed in py-faster-rcnn.\n"," \"\"\"\n"," # correct AP calculation\n"," # first append sentinel values at the end\n"," mrec = np.concatenate(([0.], recall, [1.]))\n"," mpre = np.concatenate(([0.], precision, [0.]))\n","\n"," # compute the precision envelope\n"," for i in range(mpre.size - 1, 0, -1):\n"," mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])\n","\n"," # to calculate area under PR curve, look for points\n"," # where X axis (recall) changes value\n"," i = np.where(mrec[1:] != mrec[:-1])[0]\n","\n"," # and sum (\\Delta recall) * prec\n"," ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])\n"," return ap \n","\n","def load_annotation(image_folder,annotations_folder, i, config):\n"," annots = []\n"," imgs, anns = parse_annotation(annotations_folder,image_folder,config['model']['labels'])\n"," for obj in imgs[i]['object']:\n"," annot = [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax'], config['model']['labels'].index(obj['name'])]\n"," annots += [annot]\n","\n"," if len(annots) == 0: annots = [[]]\n","\n"," return np.array(annots)\n","\n","def _calc_avg_precisions(config,image_folder,annotations_folder,weights_path,iou_threshold,score_threshold):\n","\n"," # gather all detections and annotations\n"," all_detections = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(image_folder)))]\n"," all_annotations = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(annotations_folder)))]\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," raw_image = cv2.imread(os.path.join(image_folder,sorted(os.listdir(image_folder))[i]))\n"," raw_height, raw_width, _ = raw_image.shape\n"," #print(raw_height)\n"," # make the boxes and the labels\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n"," yolo.load_weights(weights_path)\n"," pred_boxes = yolo.predict(raw_image,iou_threshold=iou_threshold,score_threshold=score_threshold)\n","\n"," score = np.array([box.score for box in pred_boxes])\n"," #print(score)\n"," pred_labels = np.array([box.label for box in pred_boxes])\n"," #print(len(pred_boxes))\n"," if len(pred_boxes) > 0:\n"," pred_boxes = np.array([[box.xmin * raw_width, box.ymin * raw_height, box.xmax * raw_width,\n"," box.ymax * raw_height, box.score] for box in pred_boxes])\n"," else:\n"," pred_boxes = np.array([[]])\n","\n"," # sort the boxes and the labels according to scores\n"," score_sort = np.argsort(-score)\n"," pred_labels = pred_labels[score_sort]\n"," 
pred_boxes = pred_boxes[score_sort]\n","\n"," # copy detections to all_detections\n"," for label in range(len(config['model']['labels'])):\n"," all_detections[i][label] = pred_boxes[pred_labels == label, :]\n","\n"," annotations = load_annotation(image_folder,annotations_folder,i,config)\n","\n"," # copy ground truth to all_annotations\n"," for label in range(len(config['model']['labels'])):\n"," all_annotations[i][label] = annotations[annotations[:, 4] == label, :4].copy()\n","\n"," # compute mAP by comparing all detections and all annotations\n"," average_precisions = {}\n"," F1_scores = {}\n"," total_recall = []\n"," total_precision = []\n"," \n"," with open(QC_model_folder+\"/Quality Control/QC_results.csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"class\", \"false positive\", \"true positive\", \"false negative\", \"recall\", \"precision\", \"accuracy\", \"f1 score\", \"average_precision\"]) \n"," \n"," for label in range(len(config['model']['labels'])):\n"," false_positives = np.zeros((0,))\n"," true_positives = np.zeros((0,))\n"," scores = np.zeros((0,))\n"," num_annotations = 0.0\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," detections = all_detections[i][label]\n"," annotations = all_annotations[i][label]\n"," num_annotations += annotations.shape[0]\n"," detected_annotations = []\n","\n"," for d in detections:\n"," scores = np.append(scores, d[4])\n","\n"," if annotations.shape[0] == 0:\n"," false_positives = np.append(false_positives, 1)\n"," true_positives = np.append(true_positives, 0)\n"," continue\n","\n"," overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)\n"," assigned_annotation = np.argmax(overlaps, axis=1)\n"," max_overlap = overlaps[0, assigned_annotation]\n","\n"," if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:\n"," false_positives = np.append(false_positives, 0)\n"," true_positives = np.append(true_positives, 1)\n"," detected_annotations.append(assigned_annotation)\n"," else:\n"," false_positives = np.append(false_positives, 1)\n"," true_positives = np.append(true_positives, 0)\n","\n"," # no annotations -> AP for this class is 0 (is this correct?)\n"," if num_annotations == 0:\n"," average_precisions[label] = 0\n"," continue\n","\n"," # sort by score\n"," indices = np.argsort(-scores)\n"," false_positives = false_positives[indices]\n"," true_positives = true_positives[indices]\n","\n"," # compute false positives and true positives\n"," false_positives = np.cumsum(false_positives)\n"," true_positives = np.cumsum(true_positives)\n","\n"," # compute recall and precision\n"," recall = true_positives / num_annotations\n"," precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)\n"," total_recall.append(recall)\n"," total_precision.append(precision)\n"," #print(precision)\n"," # compute average precision\n"," average_precision = compute_ap(recall, precision)\n"," average_precisions[label] = average_precision\n","\n"," if len(precision) != 0:\n"," F1_score = 2*(precision[-1]*recall[-1]/(precision[-1]+recall[-1]))\n"," F1_scores[label] = F1_score\n"," writer.writerow([config['model']['labels'][label], str(int(false_positives[-1])), str(int(true_positives[-1])), str(int(num_annotations-true_positives[-1])), str(recall[-1]), str(precision[-1]), str(true_positives[-1]/num_annotations), str(F1_scores[label]), str(average_precisions[label])])\n"," else:\n"," F1_score = 0\n"," F1_scores[label] = F1_score\n"," 
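# Illustrative note with hypothetical numbers (not computed by this cell): compute_overlap()
# above returns the intersection-over-union (IoU) between each detection and each ground-truth
# box. For a detection [0, 0, 10, 10] and a ground-truth box [5, 5, 15, 15], the intersection
# is 5*5 = 25 and the union is 100 + 100 - 25 = 175, giving an IoU of 25/175 ≈ 0.14; with an
# iou_threshold of 0.3 this detection would be counted as a false positive rather than a match.
# Similarly, with 8 annotated objects of a class and cumulative true_positives[-1] = 6 and
# false_positives[-1] = 2, the row written below would contain recall = 6/8 = 0.75,
# precision = 6/(6+2) = 0.75 and F1 = 2*(0.75*0.75)/(0.75+0.75) = 0.75, while compute_ap()
# integrates the full precision-recall curve rather than only these end points.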
writer.writerow([config['model']['labels'][label], str(0), str(0), str(0), str(0), str(0), str(0), str(F1_score), str(average_precisions[label])])\n"," return F1_scores, average_precisions, total_recall, total_precision\n","\n","\n","def show_frame(pred_bb, pred_classes, pred_conf, gt_bb, gt_classes, class_dict, background=np.zeros((512, 512, 3)), show_confidence=True):\n"," \"\"\"\n"," Here, we are adapting classes and functions from /~https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," \"\"\"\n"," Plot the boundingboxes\n"," :param pred_bb: (np.array) Predicted Bounding Boxes [x1, y1, x2, y2] : Shape [n_pred, 4]\n"," :param pred_classes: (np.array) Predicted Classes : Shape [n_pred]\n"," :param pred_conf: (np.array) Predicted Confidences [0.-1.] : Shape [n_pred]\n"," :param gt_bb: (np.array) Ground Truth Bounding Boxes [x1, y1, x2, y2] : Shape [n_gt, 4]\n"," :param gt_classes: (np.array) Ground Truth Classes : Shape [n_gt]\n"," :param class_dict: (dictionary) Key value pairs of classes, e.g. {0:'dog',1:'cat',2:'horse'}\n"," :return:\n"," \"\"\"\n"," n_pred = pred_bb.shape[0]\n"," n_gt = gt_bb.shape[0]\n"," n_class = int(np.max(np.append(pred_classes, gt_classes)) + 1)\n"," #print(n_class)\n"," if len(background.shape) < 3:\n"," h, w = background.shape\n"," else:\n"," h, w, c = background.shape\n","\n"," ax = plt.subplot(\"111\")\n"," ax.imshow(background)\n"," cmap = plt.cm.get_cmap('hsv')\n","\n"," confidence_alpha = pred_conf.copy()\n"," if not show_confidence:\n"," confidence_alpha.fill(1)\n","\n"," for i in range(n_pred):\n"," x1 = pred_bb[i, 0]# * w\n"," y1 = pred_bb[i, 1]# * h\n"," x2 = pred_bb[i, 2]# * w\n"," y2 = pred_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," #print(x1, y1)\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(pred_classes[i]) / n_class),\n"," linestyle='dashdot',\n"," alpha=confidence_alpha[i]))\n","\n"," for i in range(n_gt):\n"," x1 = gt_bb[i, 0]# * w\n"," y1 = gt_bb[i, 1]# * h\n"," x2 = gt_bb[i, 2]# * w\n"," y2 = gt_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(gt_classes[i]) / n_class)))\n","\n"," legend_handles = []\n","\n"," for i in range(n_class):\n"," legend_handles.append(patches.Patch(color=cmap(float(i) / n_class), label=class_dict[i]))\n"," \n"," ax.legend(handles=legend_handles)\n"," plt.show()\n","\n","class BoundBox:\n"," \"\"\"\n"," Here, we are adapting classes and functions from /~https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," def __init__(self, xmin, ymin, xmax, ymax, c = None, classes = None):\n"," self.xmin = xmin\n"," self.ymin = ymin\n"," self.xmax = xmax\n"," self.ymax = ymax\n"," \n"," self.c = c\n"," self.classes = classes\n","\n"," self.label = -1\n"," self.score = -1\n","\n"," def get_label(self):\n"," if self.label == -1:\n"," self.label = np.argmax(self.classes)\n"," \n"," return self.label\n"," \n"," def get_score(self):\n"," if self.score == -1:\n"," self.score = self.classes[self.get_label()]\n"," \n"," return self.score\n","\n","class WeightReader:\n"," def __init__(self, weight_file):\n"," self.offset = 4\n"," self.all_weights = np.fromfile(weight_file, dtype='float32')\n"," \n"," def read_bytes(self, size):\n"," self.offset = self.offset + size\n"," return self.all_weights[self.offset-size:self.offset]\n"," \n"," def reset(self):\n"," self.offset = 4\n","\n","def bbox_iou(box1, box2):\n"," intersect_w 
= _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])\n"," intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax]) \n"," \n"," intersect = intersect_w * intersect_h\n","\n"," w1, h1 = box1.xmax-box1.xmin, box1.ymax-box1.ymin\n"," w2, h2 = box2.xmax-box2.xmin, box2.ymax-box2.ymin\n"," \n"," union = w1*h1 + w2*h2 - intersect\n"," \n"," return float(intersect) / union\n","\n","def draw_boxes(image, boxes, labels):\n"," image_h, image_w, _ = image.shape\n"," #Changes in box color added by LvC\n"," # class_colours = []\n"," # for c in range(len(labels)):\n"," # colour = np.random.randint(low=0,high=255,size=3).tolist()\n"," # class_colours.append(tuple(colour))\n"," for box in boxes:\n"," xmin = int(box.xmin*image_w)\n"," ymin = int(box.ymin*image_h)\n"," xmax = int(box.xmax*image_w)\n"," ymax = int(box.ymax*image_h)\n"," if box.get_label() == 0:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (255,0,0), 3)\n"," elif box.get_label() == 1:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (0,255,0), 3)\n"," else:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (0,0,255), 3)\n"," #cv2.rectangle(image, (xmin,ymin), (xmax,ymax), class_colours[box.get_label()], 3)\n"," cv2.putText(image, \n"," labels[box.get_label()] + ' ' + str(round(box.get_score(),3)), \n"," (xmin, ymin - 13), \n"," cv2.FONT_HERSHEY_SIMPLEX, \n"," 1e-3 * image_h, \n"," (0,0,0), 2)\n"," #print(box.get_label()) \n"," return image \n","\n","#Function added by LvC\n","def save_boxes(image_path, boxes, labels):#, save_path):\n"," image = cv2.imread(image_path)\n"," image_h, image_w, _ = image.shape\n"," save_boxes =[]\n"," save_boxes_names = []\n"," save_boxes.append(os.path.basename(image_path))\n"," save_boxes_names.append(os.path.basename(image_path))\n"," for box in boxes:\n"," # xmin = box.xmin\n"," save_boxes.append(int(box.xmin*image_w))\n"," save_boxes_names.append(int(box.xmin*image_w))\n"," # ymin = box.ymin\n"," save_boxes.append(int(box.ymin*image_h))\n"," save_boxes_names.append(int(box.ymin*image_h))\n"," # xmax = box.xmax\n"," save_boxes.append(int(box.xmax*image_w))\n"," save_boxes_names.append(int(box.xmax*image_w))\n"," # ymax = box.ymax\n"," save_boxes.append(int(box.ymax*image_h))\n"," save_boxes_names.append(int(box.ymax*image_h))\n"," score = box.get_score()\n"," save_boxes.append(score)\n"," save_boxes_names.append(score)\n"," label = box.get_label()\n"," save_boxes.append(label)\n"," save_boxes_names.append(labels[label])\n"," \n"," #This file will be for later analysis of the bounding boxes in imagej\n"," if not os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," with open('/content/predicted_bounding_boxes.csv', 'w', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes)\n"," else:\n"," with open('/content/predicted_bounding_boxes.csv', 'a+', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile)\n"," csvwriter.writerow(save_boxes)\n"," \n"," if not os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," with open('/content/predicted_bounding_boxes_names.csv', 'w', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes_names)\n"," else:\n"," with 
open('/content/predicted_bounding_boxes_names.csv', 'a+', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names)\n"," csvwriter.writerow(save_boxes_names)\n"," # #This file is to create a nicer display for the output images\n"," # if not os.path.exists('/content/predicted_bounding_boxes_display.csv'):\n"," # with open('/content/predicted_bounding_boxes_display.csv', 'w', newline='') as csvfile_new:\n"," # csvwriter2 = csv.writer(csvfile_new, delimiter=',')\n"," # specs_list = ['filename','width','height','class','xmin','ymin','xmax','ymax']\n"," # csvwriter2.writerow(specs_list)\n"," # else:\n"," # with open('/content/predicted_bounding_boxes_display.csv','a+',newline='') as csvfile_new:\n"," # csvwriter2 = csv.writer(csvfile_new)\n"," # for box in boxes:\n"," # row = [os.path.basename(image_path),image_w,image_h,box.get_label(),int(box.xmin*image_w),int(box.ymin*image_h),int(box.xmax*image_w),int(box.ymax*image_h)]\n"," # csvwriter2.writerow(row)\n","\n","def add_header(inFilePath,outFilePath):\n"," header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n"," with open(inFilePath, newline='') as inFile, open(outFilePath, 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n"," \n","def decode_netout(netout, anchors, nb_class, obj_threshold=0.3, nms_threshold=0.5):\n"," grid_h, grid_w, nb_box = netout.shape[:3]\n","\n"," boxes = []\n"," \n"," # decode the output by the network\n"," netout[..., 4] = _sigmoid(netout[..., 4])\n"," netout[..., 5:] = netout[..., 4][..., np.newaxis] * _softmax(netout[..., 5:])\n"," netout[..., 5:] *= netout[..., 5:] > obj_threshold\n"," \n"," for row in range(grid_h):\n"," for col in range(grid_w):\n"," for b in range(nb_box):\n"," # from 4th element onwards are confidence and class classes\n"," classes = netout[row,col,b,5:]\n"," \n"," if np.sum(classes) > 0:\n"," # first 4 elements are x, y, w, and h\n"," x, y, w, h = netout[row,col,b,:4]\n","\n"," x = (col + _sigmoid(x)) / grid_w # center position, unit: image width\n"," y = (row + _sigmoid(y)) / grid_h # center position, unit: image height\n"," w = anchors[2 * b + 0] * np.exp(w) / grid_w # unit: image width\n"," h = anchors[2 * b + 1] * np.exp(h) / grid_h # unit: image height\n"," confidence = netout[row,col,b,4]\n"," \n"," box = BoundBox(x-w/2, y-h/2, x+w/2, y+h/2, confidence, classes)\n"," \n"," boxes.append(box)\n","\n"," # suppress non-maximal boxes\n"," for c in range(nb_class):\n"," sorted_indices = list(reversed(np.argsort([box.classes[c] for box in boxes])))\n","\n"," for i in range(len(sorted_indices)):\n"," index_i = sorted_indices[i]\n"," \n"," if boxes[index_i].classes[c] == 0: \n"," continue\n"," else:\n"," for j in range(i+1, len(sorted_indices)):\n"," index_j = sorted_indices[j]\n"," \n"," if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_threshold:\n"," boxes[index_j].classes[c] = 0\n"," \n"," # remove the boxes which are less likely than a obj_threshold\n"," boxes = [box for box in boxes if box.get_score() > obj_threshold]\n"," \n"," return boxes\n","\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file 
permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","with open(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," reduce_lr = False\n"," for line in lineReader:\n"," if \"reduce_lr\" in line:\n"," reduce_lr = True\n"," break\n","\n","if reduce_lr == False:\n"," #replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n csv_logger=CSVLogger('/content/training_evaluation.csv')\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n reduce_lr=ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1)\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"import EarlyStopping\",\"import ReduceLROnPlateau, EarlyStopping\")\n","\n","with open(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," map_eval = False\n"," for line in lineReader:\n"," if \"map_evaluation\" in line:\n"," map_eval = True\n"," break\n","\n","if map_eval == False:\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"import cv2\",\"import cv2\\nfrom map_evaluation import MapEvaluation\")\n"," new_callback = ' map_evaluator = MapEvaluation(self, valid_generator,save_best=True,save_name=\"/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5\",iou_threshold=0.3,score_threshold=0.3)'\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"write_images=False)\",\"write_images=False)\\n\"+new_callback)\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"import keras\",\"import keras\\nimport csv\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"from .utils\",\"from utils\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\".format(_map))\",\".format(_map))\\n with open('/content/gdrive/My Drive/mAP.csv','a+', newline='') as mAP_csv:\\n csv_writer=csv.writer(mAP_csv)\\n csv_writer.writerow(['mAP:','{:.4f}'.format(_map)])\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"iou_threshold=0.5\",\"iou_threshold=0.3\")\n"," replace(\"/content/gdrive/My Drive/keras-yolo2/map_evaluation.py\",\"score_threshold=0.5\",\"score_threshold=0.3\")\n","\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"[early_stop, checkpoint, tensorboard]\",\"[checkpoint, reduce_lr, map_evaluator]\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"predict(self, image)\",\"predict(self,image,iou_threshold=0.3,score_threshold=0.3)\")\n","replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\", \"self.model.summary()\",\"#self.model.summary()\")\n","from frontend import YOLO\n","\n","def train(config_path, model_path, percentage_validation):\n"," #config_path = args.conf\n","\n"," with open(config_path) as config_buffer: \n"," config = json.loads(config_buffer.read())\n","\n"," ###############################\n"," # Parse the annotations \n"," ###############################\n","\n"," # parse annotations of the training set\n"," train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'], \n"," config['train']['train_image_folder'], \n"," config['model']['labels'])\n","\n"," # parse annotations of the validation set, if any, otherwise split the training set\n"," if 
os.path.exists(config['valid']['valid_annot_folder']):\n"," valid_imgs, valid_labels = parse_annotation(config['valid']['valid_annot_folder'], \n"," config['valid']['valid_image_folder'], \n"," config['model']['labels'])\n"," else:\n"," train_valid_split = int((1-percentage_validation/100.)*len(train_imgs))\n"," np.random.shuffle(train_imgs)\n","\n"," valid_imgs = train_imgs[train_valid_split:]\n"," train_imgs = train_imgs[:train_valid_split]\n","\n"," if len(config['model']['labels']) > 0:\n"," overlap_labels = set(config['model']['labels']).intersection(set(train_labels.keys()))\n","\n"," print('Seen labels:\\t', train_labels)\n"," print('Given labels:\\t', config['model']['labels'])\n"," print('Overlap labels:\\t', overlap_labels) \n","\n"," if len(overlap_labels) < len(config['model']['labels']):\n"," print('Some labels have no annotations! Please revise the list of labels in the config.json file!')\n"," return\n"," else:\n"," print('No labels are provided. Train on all seen labels.')\n"," config['model']['labels'] = train_labels.keys()\n"," \n"," ###############################\n"," # Construct the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load the pretrained weights (if any) \n"," ############################### \n","\n"," if os.path.exists(config['train']['pretrained_weights']):\n"," print(\"Loading pre-trained weights in\", config['train']['pretrained_weights'])\n"," yolo.load_weights(config['train']['pretrained_weights'])\n"," if os.path.exists('/content/gdrive/My Drive/mAP.csv'):\n"," os.remove('/content/gdrive/My Drive/mAP.csv')\n"," ###############################\n"," # Start the training process \n"," ###############################\n","\n"," yolo.train(train_imgs = train_imgs,\n"," valid_imgs = valid_imgs,\n"," train_times = config['train']['train_times'],\n"," valid_times = config['valid']['valid_times'],\n"," nb_epochs = config['train']['nb_epochs'], \n"," learning_rate = config['train']['learning_rate'], \n"," batch_size = config['train']['batch_size'],\n"," warmup_epochs = config['train']['warmup_epochs'],\n"," object_scale = config['train']['object_scale'],\n"," no_object_scale = config['train']['no_object_scale'],\n"," coord_scale = config['train']['coord_scale'],\n"," class_scale = config['train']['class_scale'],\n"," saved_weights_name = config['train']['saved_weights_name'],\n"," debug = config['train']['debug'])\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
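#
# For reference, a minimal sketch of the config.json structure consumed by the train() and
# predict() functions in this cell (key names are taken from the calls above; the values shown
# are placeholders only and the actual defaults are set later in this notebook):
#
# {
#   "model": {"backend": "Full Yolo", "input_size": 416, "labels": ["example_class"],
#             "max_box_per_image": 10, "anchors": [...]},
#   "train": {"train_image_folder": "...", "train_annot_folder": "...", "pretrained_weights": "",
#             "train_times": 1, "nb_epochs": 30, "learning_rate": 1e-4, "batch_size": 4,
#             "warmup_epochs": 0, "object_scale": 5.0, "no_object_scale": 1.0,
#             "coord_scale": 1.0, "class_scale": 1.0, "saved_weights_name": "...", "debug": false},
#   "valid": {"valid_image_folder": "", "valid_annot_folder": "", "valid_times": 1}
# }
#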
\n"," lossDataCSVpath = os.path.join(model_path,'Quality Control/training_evaluation.csv')\n"," with open(lossDataCSVpath, 'w') as f1:\n"," writer = csv.writer(f1)\n"," mAP_df = pd.read_csv('/content/gdrive/My Drive/mAP.csv',header=None)\n"," writer.writerow(['loss','val_loss','mAP','learning rate'])\n"," for i in range(len(yolo.model.history.history['loss'])):\n"," writer.writerow([yolo.model.history.history['loss'][i], yolo.model.history.history['val_loss'][i], float(mAP_df[1][i]), yolo.model.history.history['lr'][i]])\n","\n","def predict(config, weights_path, image_path):#, model_path):\n","\n"," with open(config) as config_buffer: \n"," config = json.load(config_buffer)\n","\n"," ###############################\n"," # Make the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load trained weights\n"," ############################### \n","\n"," yolo.load_weights(weights_path)\n","\n"," ###############################\n"," # Predict bounding boxes \n"," ###############################\n","\n"," if image_path[-4:] == '.mp4':\n"," video_out = image_path[:-4] + '_detected' + image_path[-4:]\n"," video_reader = cv2.VideoCapture(image_path)\n","\n"," nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))\n"," frame_h = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))\n"," frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))\n","\n"," video_writer = cv2.VideoWriter(video_out,\n"," cv2.VideoWriter_fourcc(*'MPEG'), \n"," 50.0, \n"," (frame_w, frame_h))\n","\n"," for i in tqdm(range(nb_frames)):\n"," _, image = video_reader.read()\n"," \n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n","\n"," video_writer.write(np.uint8(image))\n","\n"," video_reader.release()\n"," video_writer.release() \n"," else:\n"," image = cv2.imread(image_path)\n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n"," save_boxes(image_path,boxes,config['model']['labels'])#,model_path)#added by LvC\n"," print(len(boxes), 'boxes are found')\n"," #print(image)\n"," cv2.imwrite(image_path[:-4] + '_detected' + image_path[-4:], image)\n"," \n"," return len(boxes)\n","\n","# function to convert BoundingBoxesOnImage object into DataFrame\n","def bbs_obj_to_df(bbs_object):\n","# convert BoundingBoxesOnImage object into array\n"," bbs_array = bbs_object.to_xyxy_array()\n","# convert array into a DataFrame ['xmin', 'ymin', 'xmax', 'ymax'] columns\n"," df_bbs = pd.DataFrame(bbs_array, columns=['xmin', 'ymin', 'xmax', 'ymax'])\n"," return df_bbs\n","\n","# Function that will extract column data for our CSV file\n","def xml_to_csv(path):\n"," xml_list = []\n"," for xml_file in glob.glob(path + '/*.xml'):\n"," tree = ET.parse(xml_file)\n"," root = tree.getroot()\n"," for member in root.findall('object'):\n"," value = (root.find('filename').text,\n"," int(root.find('size')[0].text),\n"," int(root.find('size')[1].text),\n"," member[0].text,\n"," int(float(member[4][0].text)),\n"," int(float(member[4][1].text)),\n"," int(float(member[4][2].text)),\n"," int(float(member[4][3].text))\n"," )\n"," xml_list.append(value)\n"," column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," xml_df = 
pd.DataFrame(xml_list, columns=column_name)\n"," return xml_df\n","\n","\n","\n","def image_aug(df, images_path, aug_images_path, image_prefix, augmentor):\n"," # create data frame which we're going to populate with augmented image info\n"," aug_bbs_xy = pd.DataFrame(columns=\n"," ['filename','width','height','class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," )\n"," grouped = df.groupby('filename')\n"," \n"," for filename in df['filename'].unique():\n"," # get separate data frame grouped by file name\n"," group_df = grouped.get_group(filename)\n"," group_df = group_df.reset_index()\n"," group_df = group_df.drop(['index'], axis=1) \n"," # read the image\n"," image = imageio.imread(images_path+filename)\n"," # get bounding boxes coordinates and write into array \n"," bb_array = group_df.drop(['filename', 'width', 'height', 'class'], axis=1).values\n"," # pass the array of bounding boxes coordinates to the imgaug library\n"," bbs = BoundingBoxesOnImage.from_xyxy_array(bb_array, shape=image.shape)\n"," # apply augmentation on image and on the bounding boxes\n"," image_aug, bbs_aug = augmentor(image=image, bounding_boxes=bbs)\n"," # disregard bounding boxes which have fallen out of image pane \n"," bbs_aug = bbs_aug.remove_out_of_image()\n"," # clip bounding boxes which are partially outside of image pane\n"," bbs_aug = bbs_aug.clip_out_of_image()\n"," \n"," # don't perform any actions with the image if there are no bounding boxes left in it \n"," if re.findall('Image...', str(bbs_aug)) == ['Image([]']:\n"," pass\n"," \n"," # otherwise continue\n"," else:\n"," # write augmented image to a file\n"," imageio.imwrite(aug_images_path+image_prefix+filename, image_aug) \n"," # create a data frame with augmented values of image width and height\n"," info_df = group_df.drop(['xmin', 'ymin', 'xmax', 'ymax'], axis=1) \n"," for index, _ in info_df.iterrows():\n"," info_df.at[index, 'width'] = image_aug.shape[1]\n"," info_df.at[index, 'height'] = image_aug.shape[0]\n"," # rename filenames by adding the predifined prefix\n"," info_df['filename'] = info_df['filename'].apply(lambda x: image_prefix+x)\n"," # create a data frame with augmented bounding boxes coordinates using the function we created earlier\n"," bbs_df = bbs_obj_to_df(bbs_aug)\n"," # concat all new augmented info into new data frame\n"," aug_df = pd.concat([info_df, bbs_df], axis=1)\n"," # append rows to aug_bbs_xy data frame\n"," aug_bbs_xy = pd.concat([aug_bbs_xy, aug_df]) \n"," \n"," # return dataframe with updated images and bounding boxes annotations \n"," aug_bbs_xy = aug_bbs_xy.reset_index()\n"," aug_bbs_xy = aug_bbs_xy.drop(['index'], axis=1)\n"," return aug_bbs_xy\n","\n","\n","print('-------------------------------------------')\n","print(\"Depencies installed and imported.\")\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m'\n","\n","# Check if this is the latest version of the notebook\n","\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'YOLOv2'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+'):\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell\n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = sp.run('nvcc --version',stdout=sp.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = sp.run('nvidia-smi',stdout=sp.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_Source+'/'+os.listdir(Training_Source)[1]).shape\n"," dataset_size = len(os.listdir(Training_Source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' labelled images (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and a custom loss function combining MSE and crossentropy losses, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' labelled images (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and a custom loss function combining MSE and crossentropy losses, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(multiply_dataset_by)+' by'\n"," if multiply_dataset_by >= 2:\n"," aug_text = aug_text+'\\n- flipping'\n"," if multiply_dataset_by > 2:\n"," aug_text = aug_text+'\\n- rotation'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th width = 50% align="left">Parameter</th><th width = 50% align="left">Value</th></tr>
<tr><td width = 50%>number_of_epochs</td><td width = 50%>{0}</td></tr>
<tr><td width = 50%>train_times</td><td width = 50%>{1}</td></tr>
<tr><td width = 50%>batch_size</td><td width = 50%>{2}</td></tr>
<tr><td width = 50%>learning_rate</td><td width = 50%>{3}</td></tr>
<tr><td width = 50%>false_negative_penalty</td><td width = 50%>{4}</td></tr>
<tr><td width = 50%>false_positive_penalty</td><td width = 50%>{5}</td></tr>
<tr><td width = 50%>position_size_penalty</td><td width = 50%>{6}</td></tr>
<tr><td width = 50%>false_class_penalty</td><td width = 50%>{7}</td></tr>
<tr><td width = 50%>percentage_validation</td><td width = 50%>{8}</td></tr>
\n"," \"\"\".format(number_of_epochs, train_times, batch_size, learning_rate, false_negative_penalty, false_positive_penalty, position_size_penalty, false_class_penalty, percentage_validation)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_Source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_Source_annotations, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," if visualise_example == True:\n"," pdf.cell(60, 5, txt = 'Example ground-truth annotation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_YOLOv2.png').shape\n"," pdf.image('/content/TrainingDataExample_YOLOv2.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- YOLOv2: Redmon, Joseph, and Ali Farhadi. \"YOLO9000: better, faster, stronger.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 
2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," ref_3 = '- YOLOv2 keras: /~https://github.com/experiencor/keras-yolo2, (2018)'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," if augmentation:\n"," ref_4 = '- imgaug: Jung, Alexander et al., /~https://github.com/aleju/imgaug, (2020)'\n"," pdf.multi_cell(190, 5, txt = ref_4, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'YOLOv2'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:16]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate and Time: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," if os.path.exists(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png'):\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png').shape\n"," pdf.image(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(80, 5, txt = 'P-R curves for test dataset', ln=1, align='L')\n"," pdf.ln(2)\n"," for i in range(len(AP)):\n"," if os.path.exists(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png'):\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png').shape\n"," pdf.ln(1)\n"," pdf.image(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', x=16, y=None, w=round(exp_size[1]/4), h=round(exp_size[0]/4))\n"," else:\n"," pdf.cell(100, 5, txt='For the class '+config['model']['labels'][i]+' the model did not predict any objects.', ln=1, align='L')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(QC_model_folder+'/Quality Control/QC_results.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," class_name = header[0]\n"," fp = header[1]\n"," tp = header[2]\n"," fn = header[3]\n"," recall = header[4]\n"," precision = header[5]\n"," acc = header[6]\n"," f1 = header[7]\n"," 
AP_score = header[8]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(class_name,fp,tp,fn,recall,precision,acc,f1,AP_score)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," class_name = row[0]\n"," fp = row[1]\n"," tp = row[2]\n"," fn = row[3]\n"," recall = row[4]\n"," precision = row[5]\n"," acc = row[6]\n"," f1 = row[7]\n"," AP_score = row[8]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(class_name,fp,tp,fn,str(round(float(recall),3)),str(round(float(precision),3)),str(round(float(acc),3)),str(round(float(f1),3)),str(round(float(AP_score),3)))\n"," html = html+cells\n"," html = html+\"\"\"
</table>\n","    </body>\n","    </html>\n","    
\"\"\"\n","\n"," pdf.write_html(html)\n"," pdf.cell(180, 5, txt='Mean average precision (mAP) over the all classes is: '+str(round(mAP_score,3)), ln=1, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(3)\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(QC_model_folder+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.ln(3)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- YOLOv2: Redmon, Joseph, and Ali Farhadi. \"YOLO9000: better, faster, stronger.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," ref_3 = '- YOLOv2 keras: /~https://github.com/experiencor/keras-yolo2, (2018)'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(QC_model_folder+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+QC_model_folder+'/Quality Control/')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your paths and parameters**\n","\n","---\n","\n","The code below allows the user to enter the paths to where the training data is and to define the training parameters.\n","\n","After playing the cell will display some quantitative metrics of your dataset, including a count of objects per image and the number of instances per class.\n"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_source_annotations`:** These are the paths to your folders containing the Training_source and the annotation data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. 
Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:** Enter the number of epochs the network will be trained for. More epochs can improve performance but also lengthen training. **Default value: 27**\n","\n","**Note that YOLOv2 uses 3 Warm-up epochs which improves the model's performance. This means the network will train for number_of_epochs + 3 epochs.**\n","\n","**`backend`:** There are different backends available for YOLO. These are usually slightly different model architectures, with pretrained weights. Take a look at the available backends and research which one will be best suited for your dataset.\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`train_times:`** Input how many times to cycle through the dataset per epoch. This is more useful for smaller datasets (but risks overfitting). **Default value: 4**\n","\n","**`batch_size:`** This parameter defines the number of images seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`learning_rate:`** Input the initial value to be used as learning rate. **Default value: 0.0004**\n","\n","**`false_negative_penalty:`** Penalize wrong detection of 'no-object'. **Default: 5.0**\n","\n","**`false_positive_penalty:`** Penalize wrong detection of 'object'. **Default: 1.0**\n","\n","**`position_size_penalty:`** Penalize inaccurate positioning or size of bounding boxes. **Default: 1.0**\n","\n","**`false_class_penalty:`** Penalize misclassification of object in bounding box. **Default: 1.0**\n","\n","**`percentage_validation:`** Input the percentage of your training dataset you want to use to validate the network during training. 
**Default value: 10** "]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images:\n","\n","Training_Source = \"\" #@param {type:\"string\"}\n","\n","# Ground truth images\n","Training_Source_annotations = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# backend\n","#@markdown ###Choose a backend\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","\n","\n","full_model_path = os.path.join(model_path,model_name)\n","if os.path.exists(full_model_path):\n"," print(bcolors.WARNING+'Model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","\n","# other parameters for training.\n","# @markdown ###Training Parameters\n","# @markdown Number of epochs:\n","\n","number_of_epochs = 27#@param {type:\"number\"}\n","\n","# !sed -i 's@\\\"nb_epochs\\\":.*,@\\\"nb_epochs\\\": $number_of_epochs,@g' config.json\n","\n","# #@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","train_times = 4 #@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"number\"}\n","learning_rate = 1e-4 #@param{type:\"number\"}\n","false_negative_penalty = 5.0 #@param{type:\"number\"}\n","false_positive_penalty = 1.0 #@param{type:\"number\"}\n","position_size_penalty = 1.0 #@param{type:\"number\"}\n","false_class_penalty = 1.0 #@param{type:\"number\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," train_times = 4\n"," batch_size = 8\n"," learning_rate = 1e-4\n"," false_negative_penalty = 5.0\n"," false_positive_penalty = 1.0\n"," position_size_penalty = 1.0\n"," false_class_penalty = 1.0\n"," percentage_validation = 10\n","\n","\n","df_anno = []\n","dir_anno = Training_Source_annotations\n","for fnm in os.listdir(dir_anno): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n","df_anno = pd.DataFrame(df_anno)\n","\n","maxNobj = np.max(df_anno[\"Nobj\"])\n","totalNobj = np.sum(df_anno[\"Nobj\"])\n","\n","\n","class_obj = []\n","for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n","class_obj = np.array(class_obj)\n","\n","count = Counter(class_obj[class_obj != 'nan'])\n","print(count)\n","class_nm = list(count.keys())\n","class_labels = json.dumps(class_nm)\n","class_count = list(count.values())\n","asort_class_count = np.argsort(class_count)\n","\n","class_nm = np.array(class_nm)[asort_class_count]\n","class_count = np.array(class_count)[asort_class_count]\n","\n","xs = range(len(class_count))\n","\n","\n","#Show how many objects there are in the images\n","plt.figure(figsize=(15,8))\n","plt.subplot(1,2,1)\n","plt.hist(df_anno[\"Nobj\"].values,bins=50)\n","plt.title(\"Total number of objects in the dataset: {}\".format(totalNobj))\n","plt.xlabel('Number of objects per image')\n","plt.ylabel('Occurences')\n","\n","plt.subplot(1,2,2)\n","plt.barh(xs,class_count)\n","plt.yticks(xs,class_nm)\n","plt.title(\"The number of objects per class: {} classes 
in total\".format(len(count)))\n","plt.show()\n","\n","visualise_example = False\n","Use_pretrained_model = False\n","Use_Data_augmentation = False\n","\n","full_model_path = os.path.join(model_path,model_name)\n","if os.path.exists(full_model_path):\n"," print(bcolors.WARNING+'Model folder already exists and has been overwritten.'+bcolors.NORMAL)\n"," shutil.rmtree(full_model_path)\n","\n","# Create a new directory\n","os.mkdir(full_model_path)\n","\n","pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"0JPIts19QBBz"},"source":["#@markdown ###Play this cell to visualise an example image from your dataset to make sure annotations and images are properly matched.\n","import imageio\n","visualise_example = True\n","size = 1 \n","ind_random = np.random.randint(0,df_anno.shape[0],size=size)\n","img_dir=Training_Source\n","\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\n","for irow in ind_random:\n"," row = df_anno.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img, cmap='gray') # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.axis('off')\n"," plt.savefig('/content/TrainingDataExample_YOLOv2.png',bbox_inches='tight',pad_inches=0)\n"," plt.show() ## show the plot\n","\n","pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset the `Use_Data_Augmentation` box can be unticked.\n","\n","Here, the images and bounding boxes are augmented by flipping and rotation. When doubling the dataset the images are only flipped. With each higher factor of augmentation the images added to the dataset represent one further rotation to the right by 90 degrees. 
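\n","\n","As an illustration only (this is not the notebook's exact pipeline), the flip and 90-degree rotation used below come from the imgaug library and are applied to each image together with its bounding boxes; the file name and box coordinates in this sketch are placeholders:\n","\n","```python\n","import imageio\n","import imgaug.augmenters as iaa\n","from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage\n","\n","image = imageio.imread('example.png')  # placeholder image\n","boxes = BoundingBoxesOnImage([BoundingBox(x1=10, y1=20, x2=60, y2=80, label='cell')], shape=image.shape)\n","\n","flip = iaa.Fliplr(1)                            # always flip horizontally\n","rot90 = iaa.Affine(rotate=90, fit_output=True)  # rotate by 90 degrees, resize canvas to fit\n","\n","flipped_image, flipped_boxes = flip(image=image, bounding_boxes=boxes)\n","rotated_image, rotated_boxes = rot90(image=flipped_image, bounding_boxes=flipped_boxes)\n","```\n","\n","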
8x augmentation will give a dataset that is fully rotated and flipped once."]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#@markdown ##**Augmentation Options**\n","\n","\n","Use_Data_augmentation = True #@param {type:\"boolean\"}\n","\n","multiply_dataset_by = 2 #@param {type:\"slider\", min:2, max:8, step:1}\n","\n","rotation_range = 90\n","\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\n","if (Use_Data_augmentation):\n"," print('Data Augmentation enabled')\n"," # load images as NumPy arrays and append them to images list\n"," if os.path.exists(Training_Source+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Training_Source+'/.ipynb_checkpoints')\n"," \n"," images = []\n"," for index, file in enumerate(glob.glob(Training_Source+'/*'+file_suffix)):\n"," images.append(imageio.imread(file))\n"," \n"," # how many images we have\n"," print('Augmenting {} images'.format(len(images)))\n","\n"," # apply xml_to_csv() function to convert all XML files in images/ folder into labels.csv\n"," labels_df = xml_to_csv(Training_Source_annotations)\n"," labels_df.to_csv(('/content/original_labels.csv'), index=None)\n"," \n"," # Apply flip augmentation\n"," aug = iaa.OneOf([ \n"," iaa.Fliplr(1),\n"," iaa.Flipud(1)\n"," ])\n"," aug_2 = iaa.Affine(rotate=rotation_range, fit_output=True)\n"," aug_3 = iaa.Affine(rotate=rotation_range*2, fit_output=True)\n"," aug_4 = iaa.Affine(rotate=rotation_range*3, fit_output=True)\n","\n"," #Here we create a folder that will hold the original image dataset and the augmented image dataset\n"," augmented_training_source = os.path.dirname(Training_Source)+'/'+os.path.basename(Training_Source)+'_augmentation'\n"," if os.path.exists(augmented_training_source):\n"," shutil.rmtree(augmented_training_source)\n"," os.mkdir(augmented_training_source)\n","\n"," #Here we create a folder that will hold the original image annotation dataset and the augmented image annotation dataset (the bounding boxes).\n"," augmented_training_source_annotation = os.path.dirname(Training_Source_annotations)+'/'+os.path.basename(Training_Source_annotations)+'_augmentation'\n"," if os.path.exists(augmented_training_source_annotation):\n"," shutil.rmtree(augmented_training_source_annotation)\n"," os.mkdir(augmented_training_source_annotation)\n","\n"," #Create the augmentation\n"," augmented_images_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'flip_', aug)\n"," \n"," # Concat resized_images_df and augmented_images_df together and save in a new all_labels.csv file\n"," all_labels_df = pd.concat([labels_df, augmented_images_df])\n"," all_labels_df.to_csv('/content/combined_labels.csv', index=False)\n","\n"," #Here we convert the new bounding boxes for the augmented images to PASCAL VOC .xml format\n"," def convert_to_xml(df,source,target_folder):\n"," grouped = df.groupby('filename')\n"," for file in os.listdir(source):\n"," #if file in grouped.filename:\n"," group_df = grouped.get_group(file)\n"," group_df = group_df.reset_index()\n"," group_df = group_df.drop(['index'], axis=1)\n"," #group_df = group_df.dropna(axis=0)\n"," writer = Writer(source+'/'+file,group_df.iloc[1]['width'],group_df.iloc[1]['height'])\n"," for i, row in group_df.iterrows():\n"," writer.addObject(row['class'],round(row['xmin']),round(row['ymin']),round(row['xmax']),round(row['ymax']))\n"," writer.save(target_folder+'/'+os.path.splitext(file)[0]+'.xml')\n"," 
convert_to_xml(all_labels_df,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #Second round of augmentation\n"," if multiply_dataset_by > 2:\n"," aug_labels_df_2 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_2_df = image_aug(aug_labels_df_2, augmented_training_source+'/', augmented_training_source+'/', 'rot1_90_', aug_2)\n"," all_aug_labels_df = pd.concat([augmented_images_df, augmented_images_2_df])\n"," #all_labels_df.to_csv('/content/all_labels_aug.csv', index=False)\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 3:\n"," print('Augmenting again')\n"," aug_labels_df_3 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_3_df = image_aug(aug_labels_df_3, augmented_training_source+'/', augmented_training_source+'/', 'rot2_90_', aug_2)\n"," all_aug_labels_df_3 = pd.concat([all_aug_labels_df, augmented_images_3_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_3,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #This is a preliminary remover of potential duplicates in the augmentation\n"," #Ideally, duplicates are not even produced, but this acts as a fail safe.\n"," if multiply_dataset_by==4:\n"," for file in os.listdir(augmented_training_source):\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n"," if multiply_dataset_by > 4:\n"," print('And Again')\n"," aug_labels_df_4 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_4_df = image_aug(aug_labels_df_4, augmented_training_source+'/',augmented_training_source+'/','rot3_90_', aug_2)\n"," all_aug_labels_df_4 = pd.concat([all_aug_labels_df_3, augmented_images_4_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_4,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(augmented_training_source):\n"," if file.startswith('rot3_90_rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_rot1_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n","\n"," if multiply_dataset_by > 5:\n"," print('And again')\n"," augmented_images_5_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_90_', aug_2)\n"," all_aug_labels_df_5 = 
pd.concat([all_aug_labels_df_4,augmented_images_5_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," \n"," convert_to_xml(all_aug_labels_df_5,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 6:\n"," print('And again')\n"," augmented_images_df_6 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_180_', aug_3)\n"," all_aug_labels_df_6 = pd.concat([all_aug_labels_df_5,augmented_images_df_6])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_6,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 7:\n"," print('And again')\n"," augmented_images_df_7 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_270_', aug_4)\n"," all_aug_labels_df_7 = pd.concat([all_aug_labels_df_6,augmented_images_df_7])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_7,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(Training_Source):\n"," shutil.copyfile(Training_Source+'/'+file,augmented_training_source+'/'+file)\n"," shutil.copyfile(Training_Source_annotations+'/'+os.path.splitext(file)[0]+'.xml',augmented_training_source_annotation+'/'+os.path.splitext(file)[0]+'.xml')\n"," # display new dataframe\n"," #augmented_images_df\n"," \n"," # os.chdir('/content/gdrive/My Drive/keras-yolo2')\n"," # #Change the name of the training folder\n"," # !sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$augmented_training_source/\\\",@g' config.json\n","\n"," # #Change annotation folder\n"," # !sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$augmented_training_source_annotation/\\\",@g' config.json\n","\n"," df_anno = []\n"," dir_anno = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n"," df_anno = pd.DataFrame(df_anno)\n","\n"," maxNobj = np.max(df_anno[\"Nobj\"])\n","\n"," #Write the annotations to a csv file\n"," #df_anno.to_csv(model_path+'/annot.csv', index=False)#header=False, sep=',')\n","\n"," #Show how many objects there are in the images\n"," plt.figure()\n"," plt.subplot(2,1,1)\n"," plt.hist(df_anno[\"Nobj\"].values,bins=50)\n"," plt.title(\"max N of objects per image={}\".format(maxNobj))\n"," plt.show()\n","\n"," #Show the classes and how many there are of each in the dataset\n"," class_obj = []\n"," for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n"," class_obj = np.array(class_obj)\n","\n"," count = Counter(class_obj[class_obj != 'nan'])\n"," print(count)\n"," class_nm = list(count.keys())\n"," class_labels = json.dumps(class_nm)\n"," class_count = list(count.values())\n"," asort_class_count = np.argsort(class_count)\n","\n"," class_nm = np.array(class_nm)[asort_class_count]\n"," class_count = np.array(class_count)[asort_class_count]\n","\n"," xs = range(len(class_count))\n","\n"," 
plt.subplot(2,1,2)\n"," plt.barh(xs,class_count)\n"," plt.yticks(xs,class_nm)\n"," plt.title(\"The number of objects per class: {} objects in total\".format(len(count)))\n"," plt.show()\n","\n","else:\n"," print('No augmentation will be used')\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"y7HVvJZuNU1t"},"source":["#@markdown ###Play this cell to visualise some example images from your **augmented** dataset to make sure annotations and images are properly matched.\n","if (Use_Data_augmentation):\n"," df_anno_aug = []\n"," dir_anno_aug = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno_aug): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno_aug,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_aug.append(row)\n"," df_anno_aug = pd.DataFrame(df_anno_aug)\n","\n"," size = 3 \n"," ind_random = np.random.randint(0,df_anno_aug.shape[0],size=size)\n"," img_dir=augmented_training_source\n","\n"," file_suffix = os.path.splitext(os.listdir(augmented_training_source)[0])[1]\n"," for irow in ind_random:\n"," row = df_anno_aug.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img, cmap='gray') # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.show() ## show the plot\n"," print('These are the augmented training images.')\n","\n","else:\n"," print('Data augmentation disabled.')\n","\n","# else:\n","# for irow in ind_random:\n","# row = df_anno.iloc[irow,:]\n","# path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n","# # read in image\n","# img = imageio.imread(path)\n","\n","# plt.figure(figsize=(12,12))\n","# plt.imshow(img, cmap='gray') # plot image\n","# plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n","# # for each object in the image, plot the bounding box\n","# for iplot in range(row[\"Nobj\"]):\n","# plt_rectangle(plt,\n","# label = row[\"bbx_{}_name\".format(iplot)],\n","# x1=row[\"bbx_{}_xmin\".format(iplot)],\n","# y1=row[\"bbx_{}_ymin\".format(iplot)],\n","# x2=row[\"bbx_{}_xmax\".format(iplot)],\n","# y2=row[\"bbx_{}_ymax\".format(iplot)])\n","# plt.show() ## show the plot\n","# print('These are the non-augmented training images.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a YOLOv2 model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. 
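\n","\n","For reference, the cell below expects the pretrained model folder to contain the weight file it will load (best_weights.h5 or last_weights.h5, depending on `Weights_choice`) and, for models trained with ZeroCostDL4Mic, a Quality Control/training_evaluation.csv file from which the learning rate can be recovered. A minimal sketch of that lookup (the folder path is a placeholder):\n","\n","```python\n","import os\n","import pandas as pd\n","\n","pretrained_model_path = '/content/gdrive/My Drive/models/my_yolo_model'  # placeholder\n","weights_file = os.path.join(pretrained_model_path, 'best_weights.h5')\n","\n","history = pd.read_csv(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'))\n","last_lr = history['learning rate'].iloc[-1]                            # learning rate at the end of training\n","best_lr = history.loc[history['val_loss'].idxmin(), 'learning rate']   # learning rate at the lowest validation loss\n","```\n","\n","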
**You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. "]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pretrained network\n","\n","# Training_Source = \"\" #@param{type:\"string\"}\n","# Training_Source_annotation = \"\" #@param{type:\"string\"}\n","# Check if the right files exist\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","pretrained_model_path = \"\" #@param{type:\"string\"}\n","h5_file_path = pretrained_model_path+'/'+Weights_choice+'_weights.h5'\n","\n","if not os.path.exists(h5_file_path) and Use_pretrained_model:\n"," print('WARNING pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n","# os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","# !sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"$h5_file_path\\\",@g' config.json\n","\n","if Use_pretrained_model:\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4):\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," learning_rate = bestLearningRate\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," #bestLearningRate = learning_rate\n"," #lastLearningRate = learning_rate\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. 
Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","else:\n"," print('No pre-trained models will be used.')\n","\n"," \n"," # !sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 0,@g' config.json\n"," # !sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","\n","# with open(os.path.join(pretrained_model_path, 'Quality Control', 'lr.csv'),'r') as csvfile:\n","# csvRead = pd.read_csv(csvfile, sep=',')\n","# #print(csvRead)\n"," \n","# if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4)\n","# print(\"pretrained network learning rate found\")\n","# #find the last learning rate\n","# lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n","# #Find the learning rate corresponding to the lowest validation loss\n","# min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n","# #print(min_val_loss)\n","# bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n","# if Weights_choice == \"last\":\n","# print('Last learning rate: '+str(lastLearningRate))\n","\n","# if Weights_choice == \"best\":\n","# print('Learning rate of best validation loss: '+str(bestLearningRate))\n","\n","# if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n","# bestLearningRate = initial_learning_rate\n","# lastLearningRate = initial_learning_rate\n","# print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.1. Start training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
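\n","\n","The training cell below configures keras-yolo2 by rewriting fields of its `config.json` with `sed`. Conceptually, this is equivalent to editing the file with Python's json module, as in the following sketch (illustrative only; the key layout is assumed from the keras-yolo2 repository):\n","\n","```python\n","import json\n","\n","with open('config.json') as f:\n","    config = json.load(f)\n","\n","config['model']['backend'] = backend                                # e.g. 'Full Yolo'\n","config['train']['train_image_folder'] = Training_Source + '/'\n","config['train']['train_annot_folder'] = Training_Source_annotations + '/'\n","config['train']['nb_epochs'] = number_of_epochs\n","config['train']['batch_size'] = batch_size\n","config['train']['learning_rate'] = learning_rate\n","config['train']['saved_weights_name'] = full_model_path + '/best_weights.h5'\n","\n","with open('config.json', 'w') as f:\n","    json.dump(config, f, indent=4)\n","```\n","\n","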
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Start training\n","\n","# full_model_path = os.path.join(model_path,model_name)\n","# if os.path.exists(full_model_path):\n","# print(bcolors.WARNING+'Model folder already exists and has been overwritten.'+bcolors.NORMAL)\n","# shutil.rmtree(full_model_path)\n","\n","# # Create a new directory\n","# os.mkdir(full_model_path)\n","\n","# ------------\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","\n","#os.chdir('/content/drive/My Drive/Zero-Cost Deep-Learning to Enhance Microscopy/Various dataset/Detection_Dataset_2/BCCD.v2.voc')\n","#if not os.path.exists(model_path+'/full_raccoon.h5'):\n"," # !wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1NWbrpMGLc84ow-4gXn2mloFocFGU595s' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=1NWbrpMGLc84ow-4gXn2mloFocFGU595s\" -O full_yolo_raccoon.h5 && rm -rf /tmp/cookies.txt\n","\n","\n","full_model_file_path = full_model_path+'/best_weights.h5'\n","os.chdir('/content/gdrive/My Drive/keras-yolo2/')\n","\n","#Change backend name\n","!sed -i 's@\\\"backend\\\":.*,@\\\"backend\\\": \\\"$backend\\\",@g' config.json\n","\n","#Change the name of the training folder\n","!sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$Training_Source/\\\",@g' config.json\n","\n","#Change annotation folder\n","!sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$Training_Source_annotations/\\\",@g' config.json\n","\n","#Change the name of the saved model\n","!sed -i 's@\\\"saved_weights_name\\\":.*,@\\\"saved_weights_name\\\": \\\"$full_model_file_path\\\",@g' config.json\n","\n","#Change warmup epochs for untrained model\n","!sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 3,@g' config.json\n","\n","#When defining a new model we should reset the pretrained model parameter\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"No_pretrained_weights\\\",@g' config.json\n","\n","!sed -i 
's@\\\"nb_epochs\\\":.*,@\\\"nb_epochs\\\": $number_of_epochs,@g' config.json\n","\n","!sed -i 's@\\\"train_times\\\":.*,@\\\"train_times\\\": $train_times,@g' config.json\n","!sed -i 's@\\\"batch_size\\\":.*,@\\\"batch_size\\\": $batch_size,@g' config.json\n","!sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","!sed -i 's@\\\"object_scale\":.*,@\\\"object_scale\\\": $false_negative_penalty,@g' config.json\n","!sed -i 's@\\\"no_object_scale\":.*,@\\\"no_object_scale\\\": $false_positive_penalty,@g' config.json\n","!sed -i 's@\\\"coord_scale\\\":.*,@\\\"coord_scale\\\": $position_size_penalty,@g' config.json\n","!sed -i 's@\\\"class_scale\\\":.*,@\\\"class_scale\\\": $false_class_penalty,@g' config.json\n","\n","#Write the annotations to a csv file\n","df_anno.to_csv(full_model_path+'/annotations.csv', index=False)#header=False, sep=',')\n","\n","!sed -i 's@\\\"labels\\\":.*@\\\"labels\\\": $class_labels@g' config.json\n","\n","\n","#Generate anchors for the bounding boxes\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","output = sp.getoutput('python ./gen_anchors.py -c ./config.json')\n","\n","anchors_1 = output.find(\"[\")\n","anchors_2 = output.find(\"]\")\n","\n","config_anchors = output[anchors_1:anchors_2+1]\n","!sed -i 's@\\\"anchors\\\":.*,@\\\"anchors\\\": $config_anchors,@g' config.json\n","\n","\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"$h5_file_path\\\",@g' config.json\n","\n","\n","# !sed -i 's@\\\"anchors\\\":.*,@\\\"anchors\\\": $config_anchors,@g' config.json\n","\n","\n","if Use_pretrained_model:\n"," !sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 0,@g' config.json\n"," !sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","\n","if Use_Data_augmentation:\n"," # os.chdir('/content/gdrive/My Drive/keras-yolo2')\n"," #Change the name of the training folder\n"," !sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$augmented_training_source/\\\",@g' config.json\n","\n"," #Change annotation folder\n"," !sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$augmented_training_source_annotation/\\\",@g' config.json\n","\n","\n","# ------------\n","\n","\n","\n","if os.path.exists(full_model_path+\"/Quality Control\"):\n"," shutil.rmtree(full_model_path+\"/Quality Control\")\n","os.makedirs(full_model_path+\"/Quality Control\")\n","\n","\n","start = time.time()\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","train('config.json', full_model_path, percentage_validation)\n","\n","shutil.copyfile('/content/gdrive/My Drive/keras-yolo2/config.json',full_model_path+'/config.json')\n","\n","if os.path.exists('/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5'):\n"," shutil.move('/content/gdrive/My Drive/keras-yolo2/best_map_weights.h5',full_model_path+'/best_map_weights.h5')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. 
Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_folder = full_model_path\n","\n","#print(os.path.join(model_path, model_name))\n","\n","QC_model_name = os.path.basename(QC_model_folder)\n","\n","if os.path.exists(QC_model_folder):\n"," print(\"The \"+QC_model_name+\" model will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/gdrive/My Drive/keras-yolo2/config.json'):\n"," os.remove('/content/gdrive/My Drive/keras-yolo2/config.json')\n"," shutil.copyfile(QC_model_folder+'/config.json','/content/gdrive/My Drive/keras-yolo2/config.json')\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","mAPDataFromCSV = []\n","with open(QC_model_folder+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n"," mAPDataFromCSV.append(float(row[2]))\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(20,15))\n","\n","plt.subplot(3,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(3,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","#plt.savefig(os.path.dirname(QC_model_folder)+'/Quality Control/lossCurvePlots.png')\n","#plt.show()\n","\n","plt.subplot(3,1,3)\n","plt.plot(epochNumber,mAPDataFromCSV, label='mAP score')\n","plt.title('mean average precision (mAP) vs. epoch number (linear scale)')\n","plt.ylabel('mAP score')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png',bbox_inches='tight', pad_inches=0)\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display an overlay of the input images ground-truth (solid lines) and predicted boxes (dashed lines). 
Additionally, the below cell will show the mAP value of the model on the QC data together with plots of the Precision-Recall curves for all the classes in the dataset. If you want to read in more detail about these scores, we recommend [this brief explanation](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173).\n","\n"," The images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" should contain images (e.g. as .jpg)and annotations (.xml files)!\n","\n","Since the training saves three different models, for the best validation loss (`best_weights`), best average precision (`best_mAP_weights`) and the model after the last epoch (`last_weights`), you should choose which ones you want to use for quality control or prediction. We recommend using `best_map_weights` because they should yield the best performance on the dataset. However, it can be worth testing how well `best_weights` perform too.\n","\n","**mAP score:** This refers to the mean average precision of the model on the given dataset. This value gives an indication how precise the predictions of the classes on this dataset are when compared to the ground-truth. Values closer to 1 indicate a good fit.\n","\n","**Precision:** This is the proportion of the correct classifications (true positives) in all the predictions made by the model.\n","\n","**Recall:** This is the proportion of the detected true positives in all the detectable data."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Annotations_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#@markdown ##Choose which model you want to evaluate:\n","model_choice = \"best_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","file_suffix = os.path.splitext(os.listdir(Source_QC_folder)[0])[1]\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_folder+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","#Delete old csv with box predictions if one exists\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists(Source_QC_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Source_QC_folder+'/.ipynb_checkpoints')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","n_objects = []\n","for img in os.listdir(Source_QC_folder):\n"," full_image_path = Source_QC_folder+'/'+img\n"," print('----')\n"," print(img)\n"," n_obj = predict('config.json',QC_model_folder+'/'+model_choice+'.h5',full_image_path)\n"," n_objects.append(n_obj)\n"," K.clear_session()\n","\n","for img in os.listdir(Source_QC_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Source_QC_folder+'/'+img,QC_model_folder+\"/Quality Control/Prediction/\"+img)\n","\n","#Here, we open the config file to get the classes fro the GT labels\n","config_path = '/content/gdrive/My Drive/keras-yolo2/config.json'\n","with open(config_path) as config_buffer:\n"," config = json.load(config_buffer)\n","\n","#Make a csv file to read into imagej macro, to create custom bounding 
boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(QC_model_folder+'/Quality Control/predicted_bounding_boxes_for_custom_ROI_QC.csv')\n","\n","F1_scores, AP, recall, precision = _calc_avg_precisions(config,Source_QC_folder,Annotations_QC_folder+'/',QC_model_folder+'/'+model_choice+'.h5',0.3,0.3)\n","\n","\n","\n","with open(QC_model_folder+\"/Quality Control/QC_results.csv\", \"r\") as file:\n"," x = from_csv(file)\n"," \n","print(x)\n","\n","mAP_score = sum(AP.values())/len(AP)\n","\n","print('mAP score for QC dataset: '+str(mAP_score))\n","\n","for i in range(len(AP)):\n"," if AP[i]!=0:\n"," fig = plt.figure(figsize=(8,4))\n"," if len(recall[i]) == 1:\n"," new_recall = np.linspace(0,list(recall[i])[0],10)\n"," new_precision = list(precision[i])*10\n"," fig = plt.figure(figsize=(3,2))\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', bbox_inches='tight', pad_inches=0)\n"," plt.show()\n"," else:\n"," new_recall = list(recall[i])\n"," new_recall.append(new_recall[len(new_recall)-1])\n"," new_precision = list(precision[i])\n"," new_precision.append(0)\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', bbox_inches='tight', pad_inches=0)\n"," plt.show()\n"," else:\n"," print('No object of class '+config['model']['labels'][i]+' was detected. This will lower the mAP score. 
Consider adding an image containing this class to your QC dataset to see if the model can detect this class at all.')\n","\n","\n","# --------------------------------------------------------------\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","# This will display a randomly chosen dataset input and predicted output\n","\n","print('Below is an example input, prediction and ground truth annotation from your test dataset.')\n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","file_suffix = os.path.splitext(random_choice)[1]\n","\n","plt.figure(figsize=(30,15))\n","\n","### Display Raw input ###\n","\n","x = plt.imread(Source_QC_folder+\"/\"+random_choice)\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest', cmap='gray')\n","plt.title('Input', fontsize = 12)\n","\n","### Display Predicted annotation ###\n","\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Source_QC_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," plt.imshow(image, cmap='gray') # plot image\n"," plt.title('Prediction', fontsize=12)\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n","\n","\n","### Display GT Annotation ###\n","\n","df_anno_QC_gt = []\n","for fnm in os.listdir(Annotations_QC_folder): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(Annotations_QC_folder,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_QC_gt.append(row)\n","df_anno_QC_gt = pd.DataFrame(df_anno_QC_gt)\n","#maxNobj = np.max(df_anno_QC_gt[\"Nobj\"])\n","\n","for i in range(0,df_anno_QC_gt.shape[0]):\n"," if df_anno_QC_gt.iloc[i][\"fileID\"]+file_suffix == random_choice:\n"," row = df_anno_QC_gt.iloc[i]\n","\n","img = imageio.imread(Source_QC_folder+'/'+random_choice)\n","plt.subplot(1,3,3)\n","plt.axis('off')\n","plt.imshow(img, cmap='gray') # plot image\n","plt.title('Ground Truth annotations', fontsize=12)\n","\n","# for each object in the image, plot the bounding box\n","for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])#,\n"," #fontsize=8)\n","\n","### Show the plot ###\n","plt.savefig(QC_model_folder+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. 
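\n","\n","Under the hood, the prediction cell in section 6.1 loops this notebook's `predict()` helper over every image in `Data_folder`; for a single image the call looks roughly like this (paths are placeholders):\n","\n","```python\n","# n_objects is the number of objects detected in the image; the annotated image is\n","# written next to the input with a '_detected' suffix and the box coordinates are\n","# appended to /content/predicted_bounding_boxes.csv\n","n_objects = predict('config.json',\n","                    '/path/to/model/best_map_weights.h5',\n","                    '/path/to/image.png')\n","```\n","\n","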
After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Prediction_model_path`:** This should be the folder that contains your model."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","file_suffix = os.path.splitext(os.listdir(Data_folder)[0])[1]\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model and path to model folder:\n","\n","Prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Which model do you want to use?\n","model_choice = \"best_map_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/full_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/inception_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/mobilenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/squeezenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/gdrive/My Drive/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_path = full_model_path\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/gdrive/My Drive/keras-yolo2/config.json'):\n"," os.remove('/content/gdrive/My Drive/keras-yolo2/config.json')\n"," 
shutil.copyfile(Prediction_model_path+'/config.json','/content/gdrive/My Drive/keras-yolo2/config.json')\n","\n","if os.path.exists(Prediction_model_path+'/'+model_choice+'.h5'):\n"," print(\"The \"+os.path.basename(Prediction_model_path)+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Provide the code for performing predictions and saving them\n","print(\"Images will be saved into folder:\", Result_folder)\n","\n","\n","# ----- Predictions ------\n","\n","start = time.time()\n","\n","#Remove any files that might be from the prediction of QC examples.\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_new.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names_new.csv')\n","\n","os.chdir('/content/gdrive/My Drive/keras-yolo2')\n","\n","if os.path.exists(Data_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Data_folder+'/.ipynb_checkpoints')\n","\n","n_objects = []\n","for img in os.listdir(Data_folder):\n"," full_image_path = Data_folder+'/'+img\n"," n_obj = predict('config.json',Prediction_model_path+'/'+model_choice+'.h5',full_image_path)#,Result_folder)\n"," n_objects.append(n_obj)\n"," K.clear_session()\n","for img in os.listdir(Data_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Data_folder+'/'+img,Result_folder+'/'+img)\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," #shutil.move('/content/predicted_bounding_boxes.csv',Result_folder+'/predicted_bounding_boxes.csv')\n"," print('Bounding box labels and coordinates saved to '+ Result_folder)\n","else:\n"," print('For some reason the bounding box labels and coordinates were not saved. 
Check that your predictions look as expected.')\n","\n","#Make a csv file to read into imagej macro, to create custom bounding boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(Result_folder+'/predicted_bounding_boxes_for_custom_ROI.csv')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tP1isF0PO4C1"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"ypLeYWnzO6tv","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import random\n","from matplotlib.pyplot import imread\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","print(random_choice)\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+os.path.splitext(random_choice)[0]+'_detected'+file_suffix)\n","\n","plt.figure(figsize=(20,8))\n","\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest', cmap='gray')\n","plt.title('Input')\n","\n","plt.subplot(1,3,2)\n","plt.axis('off')\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output');\n","\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","#We need to edit this predicted_bounding_boxes_new.csv file slightly to display the bounding boxes\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Data_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," plt.title('Alternative Display of Prediction')\n"," plt.imshow(image, cmap='gray') # plot image\n","\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n"," #plt.margins(0,0)\n"," #plt.subplots_adjust(left=0., right=1., top=1., bottom=0.)\n"," #plt.gca().xaxis.set_major_locator(plt.NullLocator())\n"," #plt.gca().yaxis.set_major_locator(plt.NullLocator())\n"," 
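Each row of the CSV assembled above holds the filename followed by repeating groups of six values (xmin, ymin, xmax, ymax, confidence, class), which is why the display cells step through rows with a stride of 6. As a purely illustrative sketch (not part of the notebook), the hypothetical helper below unpacks one such row into per-object dictionaries; it assumes the padded file written above is still present at /content/predicted_bounding_boxes_new.csv.

import csv

def row_to_objects(row):
    # Hypothetical helper, for illustration only: split one row of
    # [filename, xmin, ymin, xmax, ymax, confidence, class, xmin, ...]
    # into a list of per-object dictionaries, ignoring empty padding cells.
    objects = []
    for i in range(1, len(row) - 5, 6):
        group = row[i:i + 6]
        if all(v != '' for v in group):
            xmin, ymin, xmax, ymax, conf, cls = group
            objects.append({'xmin': float(xmin), 'ymin': float(ymin),
                            'xmax': float(xmax), 'ymax': float(ymax),
                            'confidence': float(conf), 'class': cls})
    return objects

with open('/content/predicted_bounding_boxes_new.csv', newline='') as f:
    reader = csv.reader(f)
    next(reader)                                   # skip the padded header
    for row in reader:
        print(row[0], row_to_objects(row))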
plt.savefig('/content/detected_cells.png',bbox_inches='tight',transparent=True,pad_inches=0)\n","plt.show() ## show the plot\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using YOLOv2!**"]}]} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"YOLOv2_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"10jmTrCraaJMEbuTBTnt34WeoG2MVVEEv","timestamp":1625150616869},{"file_id":"1VlYfohmBOvSVtkYci7R2gktMj-F32oK4","timestamp":1622645551473},{"file_id":"1bQuSKv6gvjvWhnzoIVjqUNvlC3F_-Jvw","timestamp":1619709372524},{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1610968154980},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **YOLOv2**\n","---\n","\n"," YOLOv2 is a deep-learning method designed to perform object detection and classification of objects in images, published by [Redmon and Farhadi](https://ieeexplore.ieee.org/document/8100173). This is based on the original [YOLO](https://arxiv.org/abs/1506.02640) implementation published by the same authors. YOLOv2 is trained on images with class annotations in the form of bounding boxes drawn around the objects of interest. The images are downsampled by a convolutional neural network (CNN) and objects are classified in two final fully connected layers in the network. YOLOv2 learns classification and object detection simultaneously by taking the whole input image into account, predicting many possible bounding box solutions, and then using regression to find the best bounding boxes and classifications for each object.\n","\n","**This particular notebook enables object detection and classification on 2D images given ground truth bounding boxes. If you are interested in image segmentation, you should use our U-net or Stardist notebooks instead.**\n","\n","---\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following papers: \n","\n","**YOLO9000: Better, Faster, Stronger** from Joseph Redmon and Ali Farhadi in Proceedings of the IEEE conference on computer vision and pattern recognition, 2017, (https://ieeexplore.ieee.org/document/8100173)\n","\n","**You Only Look Once: Unified, Real-Time Object Detection** from Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi in IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, (https://ieeexplore.ieee.org/document/7780460)\n","\n","**Note: The source code for this notebook is adapted for keras and can be found in: (/~https://github.com/experiencor/keras-yolo2)**\n","\n","\n","**Please also cite these original papers when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","\n","\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. 
Before getting started**\n","---\n"," Preparing the dataset carefully is essential to make this YOLOv2 notebook work. This model requires as input a set of images (currently .jpg) and as target a list of annotation files in Pascal VOC format. The annotation files should have the exact same name as the input files, except with an .xml instead of the .jpg extension. The annotation files contain the class labels and all bounding boxes for the objects for each image in your dataset. Most datasets will give the option of saving the annotations in this format or using software for hand-annotations will automatically save the annotations in this format. \n","\n"," If you want to assemble your own dataset we recommend using the open source https://www.makesense.ai/ resource. You can follow our instructions on how to label your dataset with this tool on our [wiki](/~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki/Object-Detection-(YOLOv2)).\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .png or .jpg files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Input images (Training_source)\n"," - img_1.png, img_2.png, ...\n"," - High SNR images (Training_source_annotations)\n"," - img_1.xml, img_2.xml, ...\n"," - **Quality control dataset**\n"," - Input images\n"," - img_1.png, img_2.png\n"," - High SNR images\n"," - img_1.xml, img_2.xml\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. Install YOLOv2 and Dependencies**\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"bEN_Qt10Opz-"},"source":["## **1.1. Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"5fGzcX6sOTxH","cellView":"form"},"source":["#@markdown ##Install YOLOv2 and dependencies\n","\n","!pip install pascal-voc-writer\n","!pip install fpdf\n","!pip install PTable\n","!pip install h5py==2.10\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"SNLwKiVXO000"},"source":["## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
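The dataset description above asks for Pascal VOC .xml annotations that share their base name with the corresponding image. As a minimal, self-contained sketch (not part of the notebook; the file name, class name and coordinates are made up), this is what such an annotation looks like and how it can be read back with xml.etree.ElementTree, the same parser the notebook uses internally:

import xml.etree.ElementTree as ET

# Toy Pascal VOC annotation for a hypothetical image img_1.jpg (values made up).
voc_xml = """<annotation>
  <filename>img_1.jpg</filename>
  <size><width>512</width><height>512</height><depth>1</depth></size>
  <object>
    <name>cell</name>
    <bndbox><xmin>120</xmin><ymin>85</ymin><xmax>180</xmax><ymax>140</ymax></bndbox>
  </object>
</annotation>"""

with open('img_1.xml', 'w') as f:                  # same base name as the image
    f.write(voc_xml)

root = ET.parse('img_1.xml').getroot()
for obj in root.findall('object'):
    name = obj.find('name').text
    box = obj.find('bndbox')
    coords = [int(box.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
    print(name, coords)                            # -> cell [120, 85, 180, 140]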
\n"]},{"cell_type":"markdown","metadata":{"id":"8OsMrZ8hO7D8"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'YOLOv2'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Load Key Dependencies\n","%tensorflow_version 1.x\n","\n","from pascal_voc_writer import Writer\n","from __future__ import division\n","from __future__ import print_function\n","from __future__ import absolute_import\n","import csv\n","import random\n","import pprint\n","import time\n","import numpy as np\n","from optparse import OptionParser\n","import pickle\n","import math\n","import cv2\n","import copy\n","import math\n","from matplotlib import pyplot as plt\n","from matplotlib.pyplot import imread\n","import matplotlib.patches as patches\n","import tensorflow as tf\n","import pandas as pd\n","import os\n","import shutil\n","from skimage import io\n","from sklearn.metrics import average_precision_score\n","\n","from keras.models import Model\n","from keras.layers import Flatten, Dense, Input, Conv2D, MaxPooling2D, Dropout, Reshape, Activation, Conv2D, MaxPooling2D, BatchNormalization, Lambda\n","from keras.layers.advanced_activations import LeakyReLU\n","from keras.layers.merge import concatenate\n","from keras.applications.mobilenet import MobileNet\n","from keras.applications import InceptionV3\n","from keras.applications.vgg16 import VGG16\n","from keras.applications.resnet50 import ResNet50\n","\n","from keras import backend as K\n","from keras.optimizers import Adam, SGD, RMSprop\n","from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, TimeDistributed\n","from keras.engine.topology import get_source_inputs\n","from keras.utils import layer_utils\n","from keras.utils.data_utils import get_file\n","from keras.objectives import categorical_crossentropy\n","from keras.models import Model\n","from keras.utils import generic_utils\n","from keras.engine import Layer, InputSpec\n","from keras import initializers, regularizers\n","from keras.utils import 
Sequence\n","import xml.etree.ElementTree as ET\n","from collections import OrderedDict, Counter\n","import json\n","import imageio\n","import imgaug as ia\n","from imgaug import augmenters as iaa\n","import copy\n","import cv2\n","from tqdm import tqdm\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","from pip._internal.operations.freeze import freeze\n","import subprocess as sp\n","from prettytable import from_csv\n","\n","ia.seed(1)\n","# imgaug uses matplotlib backend for displaying images\n","from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage\n","import re\n","import glob\n","\n","#Here, we import a different github repo which includes the map_evaluation.py\n","!git clone /~https://github.com/rodrigo2019/keras_yolo2.git\n","\n","#Here, we import the main github repo for this notebook and move it to the gdrive\n","!git clone /~https://github.com/experiencor/keras-yolo2.git\n","\n","#Now, we move the map_evaluation.py file to the main repo for this notebook.\n","#The source repo of the map_evaluation.py can then be ignored and is not further relevant for this notebook.\n","shutil.move('/content/keras_yolo2/keras_yolov2/map_evaluation.py','/content/keras-yolo2/map_evaluation.py')\n","\n","#We remove this branch from the notebook, to avoid confusion.\n","shutil.rmtree('/content/keras_yolo2')\n","\n","os.chdir('/content/keras-yolo2')\n","\n","\n","from backend import BaseFeatureExtractor, FullYoloFeature\n","from preprocessing import parse_annotation, BatchGenerator\n","\n","\n","\n","def plt_rectangle(plt,label,x1,y1,x2,y2,fontsize=10):\n"," '''\n"," == Input ==\n"," \n"," plt : matplotlib.pyplot object\n"," label : string containing the object class name\n"," x1 : top left corner x coordinate\n"," y1 : top left corner y coordinate\n"," x2 : bottom right corner x coordinate\n"," y2 : bottom right corner y coordinate\n"," '''\n"," linewidth = 1\n"," color = \"yellow\"\n"," plt.text(x1,y1,label,fontsize=fontsize,backgroundcolor=\"magenta\")\n"," plt.plot([x1,x1],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x2,x2],[y1,y2], linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y1,y1], linewidth=linewidth,color=color)\n"," plt.plot([x1,x2],[y2,y2], linewidth=linewidth,color=color)\n","\n","def extract_single_xml_file(tree,object_count=True):\n"," Nobj = 0\n"," row = OrderedDict()\n"," for elems in tree.iter():\n","\n"," if elems.tag == \"size\":\n"," for elem in elems:\n"," row[elem.tag] = int(elem.text)\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"name\":\n"," row[\"bbx_{}_{}\".format(Nobj,elem.tag)] = str(elem.text) \n"," if elem.tag == \"bndbox\":\n"," for k in elem:\n"," row[\"bbx_{}_{}\".format(Nobj,k.tag)] = float(k.text)\n"," Nobj += 1\n"," if object_count == True:\n"," row[\"Nobj\"] = Nobj\n"," return(row)\n","\n","def count_objects(tree):\n"," Nobj=0\n"," for elems in tree.iter():\n"," if elems.tag == \"object\":\n"," for elem in elems:\n"," if elem.tag == \"bndbox\":\n"," Nobj += 1\n"," return(Nobj)\n","\n","def compute_overlap(a, b):\n"," \"\"\"\n"," Code originally from /~https://github.com/rbgirshick/py-faster-rcnn.\n"," Parameters\n"," ----------\n"," a: (N, 4) ndarray of float\n"," b: (K, 4) ndarray of float\n"," Returns\n"," -------\n"," overlaps: (N, K) ndarray of overlap between boxes and query_boxes\n"," \"\"\"\n"," area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])\n","\n"," iw = 
np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0])\n"," ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1])\n","\n"," iw = np.maximum(iw, 0)\n"," ih = np.maximum(ih, 0)\n","\n"," ua = np.expand_dims((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), axis=1) + area - iw * ih\n","\n"," ua = np.maximum(ua, np.finfo(float).eps)\n","\n"," intersection = iw * ih\n","\n"," return intersection / ua\n","\n","def compute_ap(recall, precision):\n"," \"\"\" Compute the average precision, given the recall and precision curves.\n"," Code originally from /~https://github.com/rbgirshick/py-faster-rcnn.\n","\n"," # Arguments\n"," recall: The recall curve (list).\n"," precision: The precision curve (list).\n"," # Returns\n"," The average precision as computed in py-faster-rcnn.\n"," \"\"\"\n"," # correct AP calculation\n"," # first append sentinel values at the end\n"," mrec = np.concatenate(([0.], recall, [1.]))\n"," mpre = np.concatenate(([0.], precision, [0.]))\n","\n"," # compute the precision envelope\n"," for i in range(mpre.size - 1, 0, -1):\n"," mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])\n","\n"," # to calculate area under PR curve, look for points\n"," # where X axis (recall) changes value\n"," i = np.where(mrec[1:] != mrec[:-1])[0]\n","\n"," # and sum (\\Delta recall) * prec\n"," ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])\n"," return ap \n","\n","def load_annotation(image_folder,annotations_folder, i, config):\n"," annots = []\n"," imgs, anns = parse_annotation(annotations_folder,image_folder,config['model']['labels'])\n"," for obj in imgs[i]['object']:\n"," annot = [obj['xmin'], obj['ymin'], obj['xmax'], obj['ymax'], config['model']['labels'].index(obj['name'])]\n"," annots += [annot]\n","\n"," if len(annots) == 0: annots = [[]]\n","\n"," return np.array(annots)\n","\n","def _calc_avg_precisions(config,image_folder,annotations_folder,weights_path,iou_threshold,score_threshold):\n","\n"," # gather all detections and annotations\n"," all_detections = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(image_folder)))]\n"," all_annotations = [[None for _ in range(len(config['model']['labels']))] for _ in range(len(os.listdir(annotations_folder)))]\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," raw_image = cv2.imread(os.path.join(image_folder,sorted(os.listdir(image_folder))[i]))\n"," raw_height, raw_width, _ = raw_image.shape\n"," #print(raw_height)\n"," # make the boxes and the labels\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n"," yolo.load_weights(weights_path)\n"," pred_boxes = yolo.predict(raw_image,iou_threshold=iou_threshold,score_threshold=score_threshold)\n","\n"," score = np.array([box.score for box in pred_boxes])\n"," #print(score)\n"," pred_labels = np.array([box.label for box in pred_boxes])\n"," #print(len(pred_boxes))\n"," if len(pred_boxes) > 0:\n"," pred_boxes = np.array([[box.xmin * raw_width, box.ymin * raw_height, box.xmax * raw_width,\n"," box.ymax * raw_height, box.score] for box in pred_boxes])\n"," else:\n"," pred_boxes = np.array([[]])\n","\n"," # sort the boxes and the labels according to scores\n"," score_sort = np.argsort(-score)\n"," pred_labels = pred_labels[score_sort]\n"," pred_boxes = 
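compute_overlap() above scores every detection against every ground-truth box by intersection-over-union (IoU). A quick, self-contained numeric check of that arithmetic (not part of the notebook; plain NumPy on two made-up boxes): two 10x10 boxes offset by (5, 5) share a 25-pixel intersection over a 175-pixel union, so the IoU is about 0.143, below the 0.3 IoU threshold this notebook uses when matching detections to annotations.

import numpy as np

a = np.array([[0., 0., 10., 10.]])     # one detection,    [xmin, ymin, xmax, ymax]
b = np.array([[5., 5., 15., 15.]])     # one ground-truth box

iw = np.maximum(np.minimum(a[:, 2], b[:, 2]) - np.maximum(a[:, 0], b[:, 0]), 0)   # overlap width  = 5
ih = np.maximum(np.minimum(a[:, 3], b[:, 3]) - np.maximum(a[:, 1], b[:, 1]), 0)   # overlap height = 5
intersection = iw * ih                                                             # 25 pixels
union = 10 * 10 + 10 * 10 - intersection                                           # 175 pixels
print(intersection / union)            # ~0.143, the same value compute_overlap(a, b) returns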
pred_boxes[score_sort]\n","\n"," # copy detections to all_detections\n"," for label in range(len(config['model']['labels'])):\n"," all_detections[i][label] = pred_boxes[pred_labels == label, :]\n","\n"," annotations = load_annotation(image_folder,annotations_folder,i,config)\n","\n"," # copy ground truth to all_annotations\n"," for label in range(len(config['model']['labels'])):\n"," all_annotations[i][label] = annotations[annotations[:, 4] == label, :4].copy()\n","\n"," # compute mAP by comparing all detections and all annotations\n"," average_precisions = {}\n"," F1_scores = {}\n"," total_recall = []\n"," total_precision = []\n"," \n"," with open(QC_model_folder+\"/Quality Control/QC_results.csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"class\", \"false positive\", \"true positive\", \"false negative\", \"recall\", \"precision\", \"accuracy\", \"f1 score\", \"average_precision\"]) \n"," \n"," for label in range(len(config['model']['labels'])):\n"," false_positives = np.zeros((0,))\n"," true_positives = np.zeros((0,))\n"," scores = np.zeros((0,))\n"," num_annotations = 0.0\n","\n"," for i in range(len(os.listdir(image_folder))):\n"," detections = all_detections[i][label]\n"," annotations = all_annotations[i][label]\n"," num_annotations += annotations.shape[0]\n"," detected_annotations = []\n","\n"," for d in detections:\n"," scores = np.append(scores, d[4])\n","\n"," if annotations.shape[0] == 0:\n"," false_positives = np.append(false_positives, 1)\n"," true_positives = np.append(true_positives, 0)\n"," continue\n","\n"," overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)\n"," assigned_annotation = np.argmax(overlaps, axis=1)\n"," max_overlap = overlaps[0, assigned_annotation]\n","\n"," if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:\n"," false_positives = np.append(false_positives, 0)\n"," true_positives = np.append(true_positives, 1)\n"," detected_annotations.append(assigned_annotation)\n"," else:\n"," false_positives = np.append(false_positives, 1)\n"," true_positives = np.append(true_positives, 0)\n","\n"," # no annotations -> AP for this class is 0 (is this correct?)\n"," if num_annotations == 0:\n"," average_precisions[label] = 0\n"," continue\n","\n"," # sort by score\n"," indices = np.argsort(-scores)\n"," false_positives = false_positives[indices]\n"," true_positives = true_positives[indices]\n","\n"," # compute false positives and true positives\n"," false_positives = np.cumsum(false_positives)\n"," true_positives = np.cumsum(true_positives)\n","\n"," # compute recall and precision\n"," recall = true_positives / num_annotations\n"," precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)\n"," total_recall.append(recall)\n"," total_precision.append(precision)\n"," #print(precision)\n"," # compute average precision\n"," average_precision = compute_ap(recall, precision)\n"," average_precisions[label] = average_precision\n","\n"," if len(precision) != 0:\n"," F1_score = 2*(precision[-1]*recall[-1]/(precision[-1]+recall[-1]))\n"," F1_scores[label] = F1_score\n"," writer.writerow([config['model']['labels'][label], str(int(false_positives[-1])), str(int(true_positives[-1])), str(int(num_annotations-true_positives[-1])), str(recall[-1]), str(precision[-1]), str(true_positives[-1]/num_annotations), str(F1_scores[label]), str(average_precisions[label])])\n"," else:\n"," F1_score = 0\n"," F1_scores[label] = F1_score\n"," 
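The per-class numbers written to QC_results.csv follow the standard definitions: recall = TP / (TP + FN), precision = TP / (TP + FP), and F1 = 2 * precision * recall / (precision + recall). A tiny worked example with made-up counts (illustration only, not part of the notebook):

tp, fp, fn = 8, 2, 4                                  # made-up counts for one class
recall = tp / (tp + fn)                               # 8/12 = 0.667
precision = tp / (tp + fp)                            # 8/10 = 0.8
f1 = 2 * precision * recall / (precision + recall)    # ~0.727
print(round(recall, 3), round(precision, 3), round(f1, 3))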
writer.writerow([config['model']['labels'][label], str(0), str(0), str(0), str(0), str(0), str(0), str(F1_score), str(average_precisions[label])])\n"," return F1_scores, average_precisions, total_recall, total_precision\n","\n","\n","def show_frame(pred_bb, pred_classes, pred_conf, gt_bb, gt_classes, class_dict, background=np.zeros((512, 512, 3)), show_confidence=True):\n"," \"\"\"\n"," Here, we are adapting classes and functions from /~https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," \"\"\"\n"," Plot the boundingboxes\n"," :param pred_bb: (np.array) Predicted Bounding Boxes [x1, y1, x2, y2] : Shape [n_pred, 4]\n"," :param pred_classes: (np.array) Predicted Classes : Shape [n_pred]\n"," :param pred_conf: (np.array) Predicted Confidences [0.-1.] : Shape [n_pred]\n"," :param gt_bb: (np.array) Ground Truth Bounding Boxes [x1, y1, x2, y2] : Shape [n_gt, 4]\n"," :param gt_classes: (np.array) Ground Truth Classes : Shape [n_gt]\n"," :param class_dict: (dictionary) Key value pairs of classes, e.g. {0:'dog',1:'cat',2:'horse'}\n"," :return:\n"," \"\"\"\n"," n_pred = pred_bb.shape[0]\n"," n_gt = gt_bb.shape[0]\n"," n_class = int(np.max(np.append(pred_classes, gt_classes)) + 1)\n"," #print(n_class)\n"," if len(background.shape) < 3:\n"," h, w = background.shape\n"," else:\n"," h, w, c = background.shape\n","\n"," ax = plt.subplot(\"111\")\n"," ax.imshow(background)\n"," cmap = plt.cm.get_cmap('hsv')\n","\n"," confidence_alpha = pred_conf.copy()\n"," if not show_confidence:\n"," confidence_alpha.fill(1)\n","\n"," for i in range(n_pred):\n"," x1 = pred_bb[i, 0]# * w\n"," y1 = pred_bb[i, 1]# * h\n"," x2 = pred_bb[i, 2]# * w\n"," y2 = pred_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," #print(x1, y1)\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(pred_classes[i]) / n_class),\n"," linestyle='dashdot',\n"," alpha=confidence_alpha[i]))\n","\n"," for i in range(n_gt):\n"," x1 = gt_bb[i, 0]# * w\n"," y1 = gt_bb[i, 1]# * h\n"," x2 = gt_bb[i, 2]# * w\n"," y2 = gt_bb[i, 3]# * h\n"," rect_w = x2 - x1\n"," rect_h = y2 - y1\n"," ax.add_patch(patches.Rectangle((x1, y1), rect_w, rect_h,\n"," fill=False,\n"," edgecolor=cmap(float(gt_classes[i]) / n_class)))\n","\n"," legend_handles = []\n","\n"," for i in range(n_class):\n"," legend_handles.append(patches.Patch(color=cmap(float(i) / n_class), label=class_dict[i]))\n"," \n"," ax.legend(handles=legend_handles)\n"," plt.show()\n","\n","class BoundBox:\n"," \"\"\"\n"," Here, we are adapting classes and functions from /~https://github.com/MathGaron/mean_average_precision\n"," \"\"\"\n"," def __init__(self, xmin, ymin, xmax, ymax, c = None, classes = None):\n"," self.xmin = xmin\n"," self.ymin = ymin\n"," self.xmax = xmax\n"," self.ymax = ymax\n"," \n"," self.c = c\n"," self.classes = classes\n","\n"," self.label = -1\n"," self.score = -1\n","\n"," def get_label(self):\n"," if self.label == -1:\n"," self.label = np.argmax(self.classes)\n"," \n"," return self.label\n"," \n"," def get_score(self):\n"," if self.score == -1:\n"," self.score = self.classes[self.get_label()]\n"," \n"," return self.score\n","\n","class WeightReader:\n"," def __init__(self, weight_file):\n"," self.offset = 4\n"," self.all_weights = np.fromfile(weight_file, dtype='float32')\n"," \n"," def read_bytes(self, size):\n"," self.offset = self.offset + size\n"," return self.all_weights[self.offset-size:self.offset]\n"," \n"," def reset(self):\n"," self.offset = 4\n","\n","def bbox_iou(box1, box2):\n"," intersect_w 
= _interval_overlap([box1.xmin, box1.xmax], [box2.xmin, box2.xmax])\n"," intersect_h = _interval_overlap([box1.ymin, box1.ymax], [box2.ymin, box2.ymax]) \n"," \n"," intersect = intersect_w * intersect_h\n","\n"," w1, h1 = box1.xmax-box1.xmin, box1.ymax-box1.ymin\n"," w2, h2 = box2.xmax-box2.xmin, box2.ymax-box2.ymin\n"," \n"," union = w1*h1 + w2*h2 - intersect\n"," \n"," return float(intersect) / union\n","\n","def draw_boxes(image, boxes, labels):\n"," image_h, image_w, _ = image.shape\n"," #Changes in box color added by LvC\n"," for box in boxes:\n"," xmin = int(box.xmin*image_w)\n"," ymin = int(box.ymin*image_h)\n"," xmax = int(box.xmax*image_w)\n"," ymax = int(box.ymax*image_h)\n"," if box.get_label() == 0:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (255,0,0), 3)\n"," elif box.get_label() == 1:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (0,255,0), 3)\n"," else:\n"," cv2.rectangle(image, (xmin,ymin), (xmax,ymax), (0,0,255), 3)\n"," cv2.putText(image, \n"," labels[box.get_label()] + ' ' + str(round(box.get_score(),3)), \n"," (xmin, ymin - 13), \n"," cv2.FONT_HERSHEY_SIMPLEX, \n"," 1e-3 * image_h, \n"," (0,0,0), 2)\n"," #print(box.get_label()) \n"," return image \n","\n","#Function added by LvC\n","def save_boxes(image_path, boxes, labels):#, save_path):\n"," image = cv2.imread(image_path)\n"," image_h, image_w, _ = image.shape\n"," save_boxes =[]\n"," save_boxes_names = []\n"," save_boxes.append(os.path.basename(image_path))\n"," save_boxes_names.append(os.path.basename(image_path))\n"," for box in boxes:\n"," # xmin = box.xmin\n"," save_boxes.append(int(box.xmin*image_w))\n"," save_boxes_names.append(int(box.xmin*image_w))\n"," # ymin = box.ymin\n"," save_boxes.append(int(box.ymin*image_h))\n"," save_boxes_names.append(int(box.ymin*image_h))\n"," # xmax = box.xmax\n"," save_boxes.append(int(box.xmax*image_w))\n"," save_boxes_names.append(int(box.xmax*image_w))\n"," # ymax = box.ymax\n"," save_boxes.append(int(box.ymax*image_h))\n"," save_boxes_names.append(int(box.ymax*image_h))\n"," score = box.get_score()\n"," save_boxes.append(score)\n"," save_boxes_names.append(score)\n"," label = box.get_label()\n"," save_boxes.append(label)\n"," save_boxes_names.append(labels[label])\n"," \n"," #This file will be for later analysis of the bounding boxes in imagej\n"," if not os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," with open('/content/predicted_bounding_boxes.csv', 'w', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes)\n"," else:\n"," with open('/content/predicted_bounding_boxes.csv', 'a+', newline='') as csvfile:\n"," csvwriter = csv.writer(csvfile)\n"," csvwriter.writerow(save_boxes)\n"," \n"," if not os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," with open('/content/predicted_bounding_boxes_names.csv', 'w', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names, delimiter=',')\n"," specs_list = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*len(boxes)\n"," csvwriter.writerow(specs_list)\n"," csvwriter.writerow(save_boxes_names)\n"," else:\n"," with open('/content/predicted_bounding_boxes_names.csv', 'a+', newline='') as csvfile_names:\n"," csvwriter = csv.writer(csvfile_names)\n"," csvwriter.writerow(save_boxes_names)\n","\n","def add_header(inFilePath,outFilePath):\n"," header = ['filename']+['xmin', 
'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n"," with open(inFilePath, newline='') as inFile, open(outFilePath, 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n"," \n","def decode_netout(netout, anchors, nb_class, obj_threshold=0.3, nms_threshold=0.5):\n"," grid_h, grid_w, nb_box = netout.shape[:3]\n","\n"," boxes = []\n"," \n"," # decode the output by the network\n"," netout[..., 4] = _sigmoid(netout[..., 4])\n"," netout[..., 5:] = netout[..., 4][..., np.newaxis] * _softmax(netout[..., 5:])\n"," netout[..., 5:] *= netout[..., 5:] > obj_threshold\n"," \n"," for row in range(grid_h):\n"," for col in range(grid_w):\n"," for b in range(nb_box):\n"," # from 4th element onwards are confidence and class classes\n"," classes = netout[row,col,b,5:]\n"," \n"," if np.sum(classes) > 0:\n"," # first 4 elements are x, y, w, and h\n"," x, y, w, h = netout[row,col,b,:4]\n","\n"," x = (col + _sigmoid(x)) / grid_w # center position, unit: image width\n"," y = (row + _sigmoid(y)) / grid_h # center position, unit: image height\n"," w = anchors[2 * b + 0] * np.exp(w) / grid_w # unit: image width\n"," h = anchors[2 * b + 1] * np.exp(h) / grid_h # unit: image height\n"," confidence = netout[row,col,b,4]\n"," \n"," box = BoundBox(x-w/2, y-h/2, x+w/2, y+h/2, confidence, classes)\n"," \n"," boxes.append(box)\n","\n"," # suppress non-maximal boxes\n"," for c in range(nb_class):\n"," sorted_indices = list(reversed(np.argsort([box.classes[c] for box in boxes])))\n","\n"," for i in range(len(sorted_indices)):\n"," index_i = sorted_indices[i]\n"," \n"," if boxes[index_i].classes[c] == 0: \n"," continue\n"," else:\n"," for j in range(i+1, len(sorted_indices)):\n"," index_j = sorted_indices[j]\n"," \n"," if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_threshold:\n"," boxes[index_j].classes[c] = 0\n"," \n"," # remove the boxes which are less likely than a obj_threshold\n"," boxes = [box for box in boxes if box.get_score() > obj_threshold]\n"," \n"," return boxes\n","\n","def replace(file_path, pattern, subst):\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","with open(\"/content/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," reduce_lr = False\n"," for line in lineReader:\n"," if \"reduce_lr\" in line:\n"," reduce_lr = True\n"," break\n","\n","if reduce_lr == False:\n"," #replace(\"/content/gdrive/My Drive/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n csv_logger=CSVLogger('/content/training_evaluation.csv')\")\n"," replace(\"/content/keras-yolo2/frontend.py\",\"period=1)\",\"period=1)\\n reduce_lr=ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1)\")\n","replace(\"/content/keras-yolo2/frontend.py\",\"import EarlyStopping\",\"import ReduceLROnPlateau, EarlyStopping\")\n","\n","with open(\"/content/keras-yolo2/frontend.py\", \"r\") as check:\n"," lineReader = check.readlines()\n"," map_eval = False\n"," for line in lineReader:\n"," if \"map_evaluation\" in 
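decode_netout() above converts the raw YOLO output grid into boxes: the predicted centre offsets pass through a sigmoid and are added to the grid-cell position, while width and height scale the anchor dimensions through an exponential. A small standalone sketch of that arithmetic for one cell and one anchor (every number below is made up for illustration; a 13x13 grid is just a typical YOLOv2 value, not something fixed by this notebook):

import numpy as np

def _sig(t):
    return 1.0 / (1.0 + np.exp(-t))

grid_w = grid_h = 13                        # example output grid size
col, row = 4, 6                             # grid cell that predicts this box
tx, ty, tw, th = 0.2, -0.1, 0.5, 0.3        # raw network outputs (made up)
anchor_w, anchor_h = 2.0, 3.0               # anchor size in grid units (made up)

x = (col + _sig(tx)) / grid_w               # box centre x, fraction of image width
y = (row + _sig(ty)) / grid_h               # box centre y, fraction of image height
w = anchor_w * np.exp(tw) / grid_w          # box width,    fraction of image width
h = anchor_h * np.exp(th) / grid_h          # box height,   fraction of image height
print(x - w/2, y - h/2, x + w/2, y + h/2)   # xmin, ymin, xmax, ymax in [0, 1]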
line:\n"," map_eval = True\n"," break\n","\n","if map_eval == False:\n"," replace(\"/content/keras-yolo2/frontend.py\", \"import cv2\",\"import cv2\\nfrom map_evaluation import MapEvaluation\")\n"," new_callback = ' map_evaluator = MapEvaluation(self, valid_generator,save_best=True,save_name=\"/content/keras-yolo2/best_map_weights.h5\",iou_threshold=0.3,score_threshold=0.3)'\n"," replace(\"/content/keras-yolo2/frontend.py\",\"write_images=False)\",\"write_images=False)\\n\"+new_callback)\n"," replace(\"/content/keras-yolo2/map_evaluation.py\",\"import keras\",\"import keras\\nimport csv\")\n"," replace(\"/content/keras-yolo2/map_evaluation.py\",\"from .utils\",\"from utils\")\n"," replace(\"/content/keras-yolo2/map_evaluation.py\",\".format(_map))\",\".format(_map))\\n with open('/content/gdrive/My Drive/mAP.csv','a+', newline='') as mAP_csv:\\n csv_writer=csv.writer(mAP_csv)\\n csv_writer.writerow(['mAP:','{:.4f}'.format(_map)])\")\n"," replace(\"/content/keras-yolo2/map_evaluation.py\",\"iou_threshold=0.5\",\"iou_threshold=0.3\")\n"," replace(\"/content/keras-yolo2/map_evaluation.py\",\"score_threshold=0.5\",\"score_threshold=0.3\")\n","\n","replace(\"/content/keras-yolo2/frontend.py\", \"[early_stop, checkpoint, tensorboard]\",\"[checkpoint, reduce_lr, map_evaluator]\")\n","replace(\"/content/keras-yolo2/frontend.py\", \"predict(self, image)\",\"predict(self,image,iou_threshold=0.3,score_threshold=0.3)\")\n","replace(\"/content/keras-yolo2/frontend.py\", \"self.model.summary()\",\"#self.model.summary()\")\n","from frontend import YOLO\n","\n","def train(config_path, model_path, percentage_validation):\n"," #config_path = args.conf\n","\n"," with open(config_path) as config_buffer: \n"," config = json.loads(config_buffer.read())\n","\n"," ###############################\n"," # Parse the annotations \n"," ###############################\n","\n"," # parse annotations of the training set\n"," train_imgs, train_labels = parse_annotation(config['train']['train_annot_folder'], \n"," config['train']['train_image_folder'], \n"," config['model']['labels'])\n","\n"," # parse annotations of the validation set, if any, otherwise split the training set\n"," if os.path.exists(config['valid']['valid_annot_folder']):\n"," valid_imgs, valid_labels = parse_annotation(config['valid']['valid_annot_folder'], \n"," config['valid']['valid_image_folder'], \n"," config['model']['labels'])\n"," else:\n"," train_valid_split = int((1-percentage_validation/100.)*len(train_imgs))\n"," np.random.shuffle(train_imgs)\n","\n"," valid_imgs = train_imgs[train_valid_split:]\n"," train_imgs = train_imgs[:train_valid_split]\n","\n"," if len(config['model']['labels']) > 0:\n"," overlap_labels = set(config['model']['labels']).intersection(set(train_labels.keys()))\n","\n"," print('Seen labels:\\t', train_labels)\n"," print('Given labels:\\t', config['model']['labels'])\n"," print('Overlap labels:\\t', overlap_labels) \n","\n"," if len(overlap_labels) < len(config['model']['labels']):\n"," print('Some labels have no annotations! Please revise the list of labels in the config.json file!')\n"," return\n"," else:\n"," print('No labels are provided. 
Train on all seen labels.')\n"," config['model']['labels'] = train_labels.keys()\n"," \n"," ###############################\n"," # Construct the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load the pretrained weights (if any) \n"," ############################### \n","\n"," if os.path.exists(config['train']['pretrained_weights']):\n"," print(\"Loading pre-trained weights in\", config['train']['pretrained_weights'])\n"," yolo.load_weights(config['train']['pretrained_weights'])\n"," if os.path.exists('/content/gdrive/My Drive/mAP.csv'):\n"," os.remove('/content/gdrive/My Drive/mAP.csv')\n"," ###############################\n"," # Start the training process \n"," ###############################\n","\n"," yolo.train(train_imgs = train_imgs,\n"," valid_imgs = valid_imgs,\n"," train_times = config['train']['train_times'],\n"," valid_times = config['valid']['valid_times'],\n"," nb_epochs = config['train']['nb_epochs'], \n"," learning_rate = config['train']['learning_rate'], \n"," batch_size = config['train']['batch_size'],\n"," warmup_epochs = config['train']['warmup_epochs'],\n"," object_scale = config['train']['object_scale'],\n"," no_object_scale = config['train']['no_object_scale'],\n"," coord_scale = config['train']['coord_scale'],\n"," class_scale = config['train']['class_scale'],\n"," saved_weights_name = config['train']['saved_weights_name'],\n"," debug = config['train']['debug'])\n","\n","# The training evaluation.csv is saved (overwrites the Files if needed). 
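train() above takes every setting from a config.json; the keys it expects can be read off the calls it makes (config['model'][...], config['train'][...], config['valid'][...]). The notebook manages its own config.json elsewhere, so the dictionary below is only an illustration of that shape; every value is a placeholder, not a recommendation.

import json

example_config = {                                   # placeholder values, illustration only
    "model": {
        "backend":           "Full Yolo",
        "input_size":        416,
        "labels":            ["cell", "debris"],
        "max_box_per_image": 10,
        "anchors":           [0.57, 0.68, 1.87, 2.06, 3.34, 5.47, 7.88, 3.53, 9.77, 9.17]
    },
    "train": {
        "train_image_folder": "/content/gdrive/My Drive/Experiment_A/Training_source/",
        "train_annot_folder": "/content/gdrive/My Drive/Experiment_A/Training_source_annotations/",
        "pretrained_weights": "",
        "train_times":        1,
        "batch_size":         8,
        "learning_rate":      1e-4,
        "nb_epochs":          50,
        "warmup_epochs":      0,
        "object_scale":       5.0,
        "no_object_scale":    1.0,
        "coord_scale":        1.0,
        "class_scale":        1.0,
        "saved_weights_name": "best_weights.h5",
        "debug":              False
    },
    "valid": {
        "valid_image_folder": "",
        "valid_annot_folder": "",
        "valid_times":        1
    }
}

with open("example_config.json", "w") as f:          # hypothetical file name
    json.dump(example_config, f, indent=4)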
\n"," lossDataCSVpath = os.path.join(model_path,'Quality Control/training_evaluation.csv')\n"," with open(lossDataCSVpath, 'w') as f1:\n"," writer = csv.writer(f1)\n"," mAP_df = pd.read_csv('/content/gdrive/My Drive/mAP.csv',header=None)\n"," writer.writerow(['loss','val_loss','mAP','learning rate'])\n"," for i in range(len(yolo.model.history.history['loss'])):\n"," writer.writerow([yolo.model.history.history['loss'][i], yolo.model.history.history['val_loss'][i], float(mAP_df[1][i]), yolo.model.history.history['lr'][i]])\n","\n","def predict(config, weights_path, image_path):#, model_path):\n","\n"," with open(config) as config_buffer: \n"," config = json.load(config_buffer)\n","\n"," ###############################\n"," # Make the model \n"," ###############################\n","\n"," yolo = YOLO(backend = config['model']['backend'],\n"," input_size = config['model']['input_size'], \n"," labels = config['model']['labels'], \n"," max_box_per_image = config['model']['max_box_per_image'],\n"," anchors = config['model']['anchors'])\n","\n"," ###############################\n"," # Load trained weights\n"," ############################### \n","\n"," yolo.load_weights(weights_path)\n","\n"," ###############################\n"," # Predict bounding boxes \n"," ###############################\n","\n"," if image_path[-4:] == '.mp4':\n"," video_out = image_path[:-4] + '_detected' + image_path[-4:]\n"," video_reader = cv2.VideoCapture(image_path)\n","\n"," nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))\n"," frame_h = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))\n"," frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))\n","\n"," video_writer = cv2.VideoWriter(video_out,\n"," cv2.VideoWriter_fourcc(*'MPEG'), \n"," 50.0, \n"," (frame_w, frame_h))\n","\n"," for i in tqdm(range(nb_frames)):\n"," _, image = video_reader.read()\n"," \n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n","\n"," video_writer.write(np.uint8(image))\n","\n"," video_reader.release()\n"," video_writer.release() \n"," else:\n"," image = cv2.imread(image_path)\n"," boxes = yolo.predict(image)\n"," image = draw_boxes(image, boxes, config['model']['labels'])\n"," save_boxes(image_path,boxes,config['model']['labels'])#,model_path)#added by LvC\n"," print(len(boxes), 'boxes are found')\n"," #print(image)\n"," cv2.imwrite(image_path[:-4] + '_detected' + image_path[-4:], image)\n"," \n"," return len(boxes)\n","\n","# function to convert BoundingBoxesOnImage object into DataFrame\n","def bbs_obj_to_df(bbs_object):\n","# convert BoundingBoxesOnImage object into array\n"," bbs_array = bbs_object.to_xyxy_array()\n","# convert array into a DataFrame ['xmin', 'ymin', 'xmax', 'ymax'] columns\n"," df_bbs = pd.DataFrame(bbs_array, columns=['xmin', 'ymin', 'xmax', 'ymax'])\n"," return df_bbs\n","\n","# Function that will extract column data for our CSV file\n","def xml_to_csv(path):\n"," xml_list = []\n"," for xml_file in glob.glob(path + '/*.xml'):\n"," tree = ET.parse(xml_file)\n"," root = tree.getroot()\n"," for member in root.findall('object'):\n"," value = (root.find('filename').text,\n"," int(root.find('size')[0].text),\n"," int(root.find('size')[1].text),\n"," member[0].text,\n"," int(float(member[4][0].text)),\n"," int(float(member[4][1].text)),\n"," int(float(member[4][2].text)),\n"," int(float(member[4][3].text))\n"," )\n"," xml_list.append(value)\n"," column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," xml_df = 
pd.DataFrame(xml_list, columns=column_name)\n"," return xml_df\n","\n","\n","\n","def image_aug(df, images_path, aug_images_path, image_prefix, augmentor):\n"," # create data frame which we're going to populate with augmented image info\n"," aug_bbs_xy = pd.DataFrame(columns=\n"," ['filename','width','height','class', 'xmin', 'ymin', 'xmax', 'ymax']\n"," )\n"," grouped = df.groupby('filename')\n"," \n"," for filename in df['filename'].unique():\n"," # get separate data frame grouped by file name\n"," group_df = grouped.get_group(filename)\n"," group_df = group_df.reset_index()\n"," group_df = group_df.drop(['index'], axis=1) \n"," # read the image\n"," image = imageio.imread(images_path+filename)\n"," # get bounding boxes coordinates and write into array \n"," bb_array = group_df.drop(['filename', 'width', 'height', 'class'], axis=1).values\n"," # pass the array of bounding boxes coordinates to the imgaug library\n"," bbs = BoundingBoxesOnImage.from_xyxy_array(bb_array, shape=image.shape)\n"," # apply augmentation on image and on the bounding boxes\n"," image_aug, bbs_aug = augmentor(image=image, bounding_boxes=bbs)\n"," # disregard bounding boxes which have fallen out of image pane \n"," bbs_aug = bbs_aug.remove_out_of_image()\n"," # clip bounding boxes which are partially outside of image pane\n"," bbs_aug = bbs_aug.clip_out_of_image()\n"," \n"," # don't perform any actions with the image if there are no bounding boxes left in it \n"," if re.findall('Image...', str(bbs_aug)) == ['Image([]']:\n"," pass\n"," \n"," # otherwise continue\n"," else:\n"," # write augmented image to a file\n"," imageio.imwrite(aug_images_path+image_prefix+filename, image_aug) \n"," # create a data frame with augmented values of image width and height\n"," info_df = group_df.drop(['xmin', 'ymin', 'xmax', 'ymax'], axis=1) \n"," for index, _ in info_df.iterrows():\n"," info_df.at[index, 'width'] = image_aug.shape[1]\n"," info_df.at[index, 'height'] = image_aug.shape[0]\n"," # rename filenames by adding the predifined prefix\n"," info_df['filename'] = info_df['filename'].apply(lambda x: image_prefix+x)\n"," # create a data frame with augmented bounding boxes coordinates using the function we created earlier\n"," bbs_df = bbs_obj_to_df(bbs_aug)\n"," # concat all new augmented info into new data frame\n"," aug_df = pd.concat([info_df, bbs_df], axis=1)\n"," # append rows to aug_bbs_xy data frame\n"," aug_bbs_xy = pd.concat([aug_bbs_xy, aug_df]) \n"," \n"," # return dataframe with updated images and bounding boxes annotations \n"," aug_bbs_xy = aug_bbs_xy.reset_index()\n"," aug_bbs_xy = aug_bbs_xy.drop(['index'], axis=1)\n"," return aug_bbs_xy\n","\n","\n","print('-------------------------------------------')\n","print(\"Depencies installed and imported.\")\n","\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n"," NORMAL = '\\033[0m'\n","\n","# Check if this is the latest version of the notebook\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," 
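image_aug() above depends on imgaug applying the same geometric transform to an image and to its bounding boxes in one call. A self-contained mini-example of that mechanism (dummy image, made-up box, and an augmenter chosen purely for illustration):

import numpy as np
from imgaug import augmenters as iaa
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage

image = np.zeros((128, 128, 3), dtype=np.uint8)                   # dummy image
bbs = BoundingBoxesOnImage([BoundingBox(x1=20, y1=30, x2=60, y2=80)],
                           shape=image.shape)

augmenter = iaa.Fliplr(1.0)                                       # always flip horizontally
image_aug, bbs_aug = augmenter(image=image, bounding_boxes=bbs)   # same call pattern as image_aug() above

print(bbs.bounding_boxes[0])       # original box: x1=20, x2=60
print(bbs_aug.bounding_boxes[0])   # mirrored box: x1=68, x2=108 in the 128-pixel-wide image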
print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","#Create a pdf document with training summary\n","\n","# save FPDF() class into a \n","# variable pdf \n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'YOLOv2'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+'):\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell\n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = sp.run('nvcc --version',stdout=sp.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = sp.run('nvidia-smi',stdout=sp.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_Source+'/'+os.listdir(Training_Source)[1]).shape\n"," dataset_size = len(os.listdir(Training_Source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' labelled images (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and a custom loss function combining MSE and crossentropy losses, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' labelled images (image dimensions: '+str(shape)+') with a batch size of '+str(batch_size)+' and a custom loss function combining MSE and crossentropy losses, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(multiply_dataset_by)+' by'\n"," if multiply_dataset_by >= 2:\n"," aug_text = aug_text+'\\n- flipping'\n"," if multiply_dataset_by > 2:\n"," aug_text = aug_text+'\\n- rotation'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table width=40%>
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>train_times</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>learning_rate</td><td>{3}</td></tr>
<tr><td>false_negative_penalty</td><td>{4}</td></tr>
<tr><td>false_positive_penalty</td><td>{5}</td></tr>
<tr><td>position_size_penalty</td><td>{6}</td></tr>
<tr><td>false_class_penalty</td><td>{7}</td></tr>
<tr><td>percentage_validation</td><td>{8}</td></tr>
</table>
\n"," \"\"\".format(number_of_epochs, train_times, batch_size, learning_rate, false_negative_penalty, false_positive_penalty, position_size_penalty, false_class_penalty, percentage_validation)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_Source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_Source_annotations, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," if visualise_example == True:\n"," pdf.cell(60, 5, txt = 'Example ground-truth annotation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_YOLOv2.png').shape\n"," pdf.image('/content/TrainingDataExample_YOLOv2.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- YOLOv2: Redmon, Joseph, and Ali Farhadi. \"YOLO9000: better, faster, stronger.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 
2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," ref_3 = '- YOLOv2 keras: /~https://github.com/experiencor/keras-yolo2, (2018)'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," if augmentation:\n"," ref_4 = '- imgaug: Jung, Alexander et al., /~https://github.com/aleju/imgaug, (2020)'\n"," pdf.multi_cell(190, 5, txt = ref_4, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_training_report.pdf')\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+model_path+'/'+model_name+'/')\n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'YOLOv2'\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:16]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate and Time: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," if os.path.exists(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png'):\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png').shape\n"," pdf.image(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size=10)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(80, 5, txt = 'P-R curves for test dataset', ln=1, align='L')\n"," pdf.ln(2)\n"," for i in range(len(AP)):\n"," if os.path.exists(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png'):\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png').shape\n"," pdf.ln(1)\n"," pdf.image(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', x=16, y=None, w=round(exp_size[1]/4), h=round(exp_size[0]/4))\n"," else:\n"," pdf.cell(100, 5, txt='For the class '+config['model']['labels'][i]+' the model did not predict any objects.', ln=1, align='L')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(QC_model_folder+'/Quality Control/QC_results.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," class_name = header[0]\n"," fp = header[1]\n"," tp = header[2]\n"," fn = header[3]\n"," recall = header[4]\n"," precision = header[5]\n"," acc = header[6]\n"," f1 = header[7]\n"," 
AP_score = header[8]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(class_name,fp,tp,fn,recall,precision,acc,f1,AP_score)\n"," html = html+header\n"," i=0\n"," for row in metrics:\n"," i+=1\n"," class_name = row[0]\n"," fp = row[1]\n"," tp = row[2]\n"," fn = row[3]\n"," recall = row[4]\n"," precision = row[5]\n"," acc = row[6]\n"," f1 = row[7]\n"," AP_score = row[8]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(class_name,fp,tp,fn,str(round(float(recall),3)),str(round(float(precision),3)),str(round(float(acc),3)),str(round(float(f1),3)),str(round(float(AP_score),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}{5}{6}{7}{8}
{0}{1}{2}{3}{4}{5}{6}{7}{8}
\"\"\"\n","\n"," pdf.write_html(html)\n"," pdf.cell(180, 5, txt='Mean average precision (mAP) over the all classes is: '+str(round(mAP_score,3)), ln=1, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(3)\n"," exp_size = io.imread(QC_model_folder+'/Quality Control/QC_example_data.png').shape\n"," pdf.image(QC_model_folder+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/10), h = round(exp_size[0]/10))\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.ln(3)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- YOLOv2: Redmon, Joseph, and Ali Farhadi. \"YOLO9000: better, faster, stronger.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," ref_3 = '- YOLOv2 keras: /~https://github.com/experiencor/keras-yolo2, (2018)'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(QC_model_folder+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","\n"," print('------------------------------')\n"," print('PDF report exported in '+QC_model_folder+'/Quality Control/')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. 
In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vX0BPrgq9aW2"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
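If you want to confirm that the mount worked before moving on, a quick sanity check such as the sketch below can help. This is only an illustration: it assumes the default mount point created by `drive.mount('/content/gdrive')` (`/content/gdrive/MyDrive`; older Colab sessions may expose it as `My Drive` instead).

```python
import os

# Illustrative check that Google Drive is mounted (path is the default used by drive.mount).
drive_root = "/content/gdrive/MyDrive"  # may be "/content/gdrive/My Drive" on older Colab sessions
if os.path.isdir(drive_root):
    print("Google Drive is mounted. First items found:")
    print(os.listdir(drive_root)[:5])
else:
    print("Drive folder not found - re-run the mounting cell above and complete the authorization step.")
```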
"]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your paths and parameters**\n","\n","---\n","\n","The code below allows the user to enter the paths to where the training data is and to define the training parameters.\n","\n","After playing the cell will display some quantitative metrics of your dataset, including a count of objects per image and the number of instances per class.\n"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","\n","**`Training_source:`, `Training_source_annotations`:** These are the paths to your folders containing the Training_source and the annotation data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training Parameters**\n","\n","**`number_of_epochs`:**Give estimates for training performance given a number of epochs and provide a default value. **Default value: 27**\n","\n","**Note that YOLOv2 uses 3 Warm-up epochs which improves the model's performance. This means the network will train for number_of_epochs + 3 epochs.**\n","\n","**`backend`:** There are different backends which are available to be trained for YOLO. These are usually slightly different model architectures, with pretrained weights. Take a look at the available backends and research which one will be best suited for your dataset.\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`train_times:`**Input how many times to cycle through the dataset per epoch. This is more useful for smaller datasets (but risks overfitting). **Default value: 4**\n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 16**\n","\n","**`learning_rate:`** Input the initial value to be used as learning rate. **Default value: 0.0004**\n","\n","**`false_negative_penalty:`** Penalize wrong detection of 'no-object'. **Default: 5.0**\n","\n","**`false_positive_penalty:`** Penalize wrong detection of 'object'. **Default: 1.0**\n","\n","**`position_size_penalty:`** Penalize inaccurate positioning or size of bounding boxes. **Default:1.0**\n","\n","**`false_class_penalty:`** Penalize misclassification of object in bounding box. **Default: 1.0**\n","\n","**`percentage_validation:`** Input the percentage of your training dataset you want to use to validate the network during training. 
**Default value: 10** "]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Path to training images:\n","\n","Training_Source = \"\" #@param {type:\"string\"}\n","\n","# Ground truth images\n","Training_Source_annotations = \"\" #@param {type:\"string\"}\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# backend\n","#@markdown ###Choose a backend\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","\n","\n","full_model_path = os.path.join(model_path,model_name)\n","if os.path.exists(full_model_path):\n"," print(bcolors.WARNING+'Model folder already exists and will be overwritten.'+bcolors.NORMAL)\n","\n","\n","# other parameters for training.\n","# @markdown ###Training Parameters\n","# @markdown Number of epochs:\n","\n","number_of_epochs = 27#@param {type:\"number\"}\n","\n","# !sed -i 's@\\\"nb_epochs\\\":.*,@\\\"nb_epochs\\\": $number_of_epochs,@g' config.json\n","\n","# #@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","train_times = 4 #@param {type:\"integer\"}\n","batch_size = 16#@param {type:\"number\"}\n","learning_rate = 1e-4 #@param{type:\"number\"}\n","false_negative_penalty = 5.0 #@param{type:\"number\"}\n","false_positive_penalty = 1.0 #@param{type:\"number\"}\n","position_size_penalty = 1.0 #@param{type:\"number\"}\n","false_class_penalty = 1.0 #@param{type:\"number\"}\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," train_times = 4\n"," batch_size = 8\n"," learning_rate = 1e-4\n"," false_negative_penalty = 5.0\n"," false_positive_penalty = 1.0\n"," position_size_penalty = 1.0\n"," false_class_penalty = 1.0\n"," percentage_validation = 10\n","\n","\n","df_anno = []\n","dir_anno = Training_Source_annotations\n","for fnm in os.listdir(dir_anno): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n","df_anno = pd.DataFrame(df_anno)\n","\n","maxNobj = np.max(df_anno[\"Nobj\"])\n","totalNobj = np.sum(df_anno[\"Nobj\"])\n","\n","\n","class_obj = []\n","for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n","class_obj = np.array(class_obj)\n","\n","count = Counter(class_obj[class_obj != 'nan'])\n","print(count)\n","class_nm = list(count.keys())\n","class_labels = json.dumps(class_nm)\n","class_count = list(count.values())\n","asort_class_count = np.argsort(class_count)\n","\n","class_nm = np.array(class_nm)[asort_class_count]\n","class_count = np.array(class_count)[asort_class_count]\n","\n","xs = range(len(class_count))\n","\n","\n","#Show how many objects there are in the images\n","plt.figure(figsize=(15,8))\n","plt.subplot(1,2,1)\n","plt.hist(df_anno[\"Nobj\"].values,bins=50)\n","plt.title(\"Total number of objects in the dataset: {}\".format(totalNobj))\n","plt.xlabel('Number of objects per image')\n","plt.ylabel('Occurences')\n","\n","plt.subplot(1,2,2)\n","plt.barh(xs,class_count)\n","plt.yticks(xs,class_nm)\n","plt.title(\"The number of objects per class: {} classes 
in total\".format(len(count)))\n","plt.show()\n","\n","visualise_example = False\n","Use_pretrained_model = False\n","Use_Data_augmentation = False\n","\n","full_model_path = os.path.join(model_path,model_name)\n","if os.path.exists(full_model_path):\n"," print(bcolors.WARNING+'Model folder already exists and has been overwritten.'+bcolors.NORMAL)\n"," shutil.rmtree(full_model_path)\n","\n","# Create a new directory\n","os.mkdir(full_model_path)\n","\n","pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"0JPIts19QBBz","cellView":"form"},"source":["#@markdown ###Play this cell to visualise an example image from your dataset to make sure annotations and images are properly matched.\n","import imageio\n","visualise_example = True\n","size = 1 \n","ind_random = np.random.randint(0,df_anno.shape[0],size=size)\n","img_dir=Training_Source\n","\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\n","for irow in ind_random:\n"," row = df_anno.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img, cmap='gray') # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.axis('off')\n"," plt.savefig('/content/TrainingDataExample_YOLOv2.png',bbox_inches='tight',pad_inches=0)\n"," plt.show() ## show the plot\n","\n","pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["##**3.2. Data augmentation**\n","\n","---\n","\n"," Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if the dataset the `Use_Data_Augmentation` box can be unticked.\n","\n","Here, the images and bounding boxes are augmented by flipping and rotation. When doubling the dataset the images are only flipped. With each higher factor of augmentation the images added to the dataset represent one further rotation to the right by 90 degrees. 
8x augmentation will give a dataset that is fully rotated and flipped once."]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#@markdown ##**Augmentation Options**\n","\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","multiply_dataset_by = 2 #@param {type:\"slider\", min:2, max:8, step:1}\n","\n","rotation_range = 90\n","\n","file_suffix = os.path.splitext(os.listdir(Training_Source)[0])[1]\n","if (Use_Data_augmentation):\n"," print('Data Augmentation enabled')\n"," # load images as NumPy arrays and append them to images list\n"," if os.path.exists(Training_Source+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Training_Source+'/.ipynb_checkpoints')\n"," \n"," images = []\n"," for index, file in enumerate(glob.glob(Training_Source+'/*'+file_suffix)):\n"," images.append(imageio.imread(file))\n"," \n"," # how many images we have\n"," print('Augmenting {} images'.format(len(images)))\n","\n"," # apply xml_to_csv() function to convert all XML files in images/ folder into labels.csv\n"," labels_df = xml_to_csv(Training_Source_annotations)\n"," labels_df.to_csv(('/content/original_labels.csv'), index=None)\n"," \n"," # Apply flip augmentation\n"," aug = iaa.OneOf([ \n"," iaa.Fliplr(1),\n"," iaa.Flipud(1)\n"," ])\n"," aug_2 = iaa.Affine(rotate=rotation_range, fit_output=True)\n"," aug_3 = iaa.Affine(rotate=rotation_range*2, fit_output=True)\n"," aug_4 = iaa.Affine(rotate=rotation_range*3, fit_output=True)\n","\n"," #Here we create a folder that will hold the original image dataset and the augmented image dataset\n"," augmented_training_source = os.path.dirname(Training_Source)+'/'+os.path.basename(Training_Source)+'_augmentation'\n"," if os.path.exists(augmented_training_source):\n"," shutil.rmtree(augmented_training_source)\n"," os.mkdir(augmented_training_source)\n","\n"," #Here we create a folder that will hold the original image annotation dataset and the augmented image annotation dataset (the bounding boxes).\n"," augmented_training_source_annotation = os.path.dirname(Training_Source_annotations)+'/'+os.path.basename(Training_Source_annotations)+'_augmentation'\n"," if os.path.exists(augmented_training_source_annotation):\n"," shutil.rmtree(augmented_training_source_annotation)\n"," os.mkdir(augmented_training_source_annotation)\n","\n"," #Create the augmentation\n"," augmented_images_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'flip_', aug)\n"," \n"," # Concat resized_images_df and augmented_images_df together and save in a new all_labels.csv file\n"," all_labels_df = pd.concat([labels_df, augmented_images_df])\n"," all_labels_df.to_csv('/content/combined_labels.csv', index=False)\n","\n"," #Here we convert the new bounding boxes for the augmented images to PASCAL VOC .xml format\n"," def convert_to_xml(df,source,target_folder):\n"," grouped = df.groupby('filename')\n"," for file in os.listdir(source):\n"," #if file in grouped.filename:\n"," group_df = grouped.get_group(file)\n"," group_df = group_df.reset_index()\n"," group_df = group_df.drop(['index'], axis=1)\n"," #group_df = group_df.dropna(axis=0)\n"," writer = Writer(source+'/'+file,group_df.iloc[0]['width'],group_df.iloc[0]['height'])\n"," for i, row in group_df.iterrows():\n"," writer.addObject(row['class'],round(row['xmin']),round(row['ymin']),round(row['xmax']),round(row['ymax']))\n"," writer.save(target_folder+'/'+os.path.splitext(file)[0]+'.xml')\n"," 
convert_to_xml(all_labels_df,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #Second round of augmentation\n"," if multiply_dataset_by > 2:\n"," aug_labels_df_2 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_2_df = image_aug(aug_labels_df_2, augmented_training_source+'/', augmented_training_source+'/', 'rot1_90_', aug_2)\n"," all_aug_labels_df = pd.concat([augmented_images_df, augmented_images_2_df])\n"," #all_labels_df.to_csv('/content/all_labels_aug.csv', index=False)\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 3:\n"," print('Augmenting again')\n"," aug_labels_df_3 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_3_df = image_aug(aug_labels_df_3, augmented_training_source+'/', augmented_training_source+'/', 'rot2_90_', aug_2)\n"," all_aug_labels_df_3 = pd.concat([all_aug_labels_df, augmented_images_3_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_3,augmented_training_source,augmented_training_source_annotation)\n"," \n"," #This is a preliminary remover of potential duplicates in the augmentation\n"," #Ideally, duplicates are not even produced, but this acts as a fail safe.\n"," if multiply_dataset_by==4:\n"," for file in os.listdir(augmented_training_source):\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n"," if multiply_dataset_by > 4:\n"," print('And Again')\n"," aug_labels_df_4 = xml_to_csv(augmented_training_source_annotation)\n"," augmented_images_4_df = image_aug(aug_labels_df_4, augmented_training_source+'/',augmented_training_source+'/','rot3_90_', aug_2)\n"," all_aug_labels_df_4 = pd.concat([all_aug_labels_df_3, augmented_images_4_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_4,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(augmented_training_source):\n"," if file.startswith('rot3_90_rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_rot1_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot3_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n"," if file.startswith('rot2_90_flip_'):\n"," os.remove(os.path.join(augmented_training_source,file))\n"," os.remove(os.path.join(augmented_training_source_annotation, os.path.splitext(file)[0]+'.xml'))\n","\n","\n"," if multiply_dataset_by > 5:\n"," print('And again')\n"," augmented_images_5_df = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_90_', aug_2)\n"," all_aug_labels_df_5 = 
pd.concat([all_aug_labels_df_4,augmented_images_5_df])\n","\n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," \n"," convert_to_xml(all_aug_labels_df_5,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 6:\n"," print('And again')\n"," augmented_images_df_6 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_180_', aug_3)\n"," all_aug_labels_df_6 = pd.concat([all_aug_labels_df_5,augmented_images_df_6])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_6,augmented_training_source,augmented_training_source_annotation)\n","\n"," if multiply_dataset_by > 7:\n"," print('And again')\n"," augmented_images_df_7 = image_aug(labels_df, Training_Source+'/', augmented_training_source+'/', 'rot_270_', aug_4)\n"," all_aug_labels_df_7 = pd.concat([all_aug_labels_df_6,augmented_images_df_7])\n"," \n"," for file in os.listdir(augmented_training_source_annotation):\n"," os.remove(os.path.join(augmented_training_source_annotation,file))\n"," convert_to_xml(all_aug_labels_df_7,augmented_training_source,augmented_training_source_annotation)\n","\n"," for file in os.listdir(Training_Source):\n"," shutil.copyfile(Training_Source+'/'+file,augmented_training_source+'/'+file)\n"," shutil.copyfile(Training_Source_annotations+'/'+os.path.splitext(file)[0]+'.xml',augmented_training_source_annotation+'/'+os.path.splitext(file)[0]+'.xml')\n","\n"," df_anno = []\n"," dir_anno = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno.append(row)\n"," df_anno = pd.DataFrame(df_anno)\n","\n"," maxNobj = np.max(df_anno[\"Nobj\"])\n","\n"," #Show how many objects there are in the images\n"," plt.figure()\n"," plt.subplot(2,1,1)\n"," plt.hist(df_anno[\"Nobj\"].values,bins=50)\n"," plt.title(\"max N of objects per image={}\".format(maxNobj))\n"," plt.show()\n","\n"," #Show the classes and how many there are of each in the dataset\n"," class_obj = []\n"," for ibbx in range(maxNobj):\n"," class_obj.extend(df_anno[\"bbx_{}_name\".format(ibbx)].values)\n"," class_obj = np.array(class_obj)\n","\n"," count = Counter(class_obj[class_obj != 'nan'])\n"," print(count)\n"," class_nm = list(count.keys())\n"," class_labels = json.dumps(class_nm)\n"," class_count = list(count.values())\n"," asort_class_count = np.argsort(class_count)\n","\n"," class_nm = np.array(class_nm)[asort_class_count]\n"," class_count = np.array(class_count)[asort_class_count]\n","\n"," xs = range(len(class_count))\n","\n"," plt.subplot(2,1,2)\n"," plt.barh(xs,class_count)\n"," plt.yticks(xs,class_nm)\n"," plt.title(\"The number of objects per class: {} objects in total\".format(len(count)))\n"," plt.show()\n","\n","else:\n"," print('No augmentation will be used')\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"y7HVvJZuNU1t","cellView":"form"},"source":["#@markdown ###Play this cell to visualise some example images from your **augmented** dataset to make sure annotations and images are properly 
matched.\n","if (Use_Data_augmentation):\n"," df_anno_aug = []\n"," dir_anno_aug = augmented_training_source_annotation\n"," for fnm in os.listdir(dir_anno_aug): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(dir_anno_aug,fnm))\n"," row = extract_single_xml_file(tree)\n"," row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_aug.append(row)\n"," df_anno_aug = pd.DataFrame(df_anno_aug)\n","\n"," size = 3 \n"," ind_random = np.random.randint(0,df_anno_aug.shape[0],size=size)\n"," img_dir=augmented_training_source\n","\n"," file_suffix = os.path.splitext(os.listdir(augmented_training_source)[0])[1]\n"," for irow in ind_random:\n"," row = df_anno_aug.iloc[irow,:]\n"," path = os.path.join(img_dir, row[\"fileID\"] + file_suffix)\n"," # read in image\n"," img = imageio.imread(path)\n","\n"," plt.figure(figsize=(12,12))\n"," plt.imshow(img, cmap='gray') # plot image\n"," plt.title(\"Nobj={}, height={}, width={}\".format(row[\"Nobj\"],row[\"height\"],row[\"width\"]))\n"," # for each object in the image, plot the bounding box\n"," for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])\n"," plt.show() ## show the plot\n"," print('These are the augmented training images.')\n","\n","else:\n"," print('Data augmentation disabled.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a YOLOv2 model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n","\n"," In order to continue training from the point where the pre-trained model left off, it is adviseable to also **load the learning rate** that was used when the training ended. This is automatically saved for models trained with ZeroCostDL4Mic and will be loaded here. If no learning rate can be found in the model folder provided, the default learning rate will be used. 
"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pretrained network\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","Weights_choice = \"best\" #@param [\"last\", \"best\"]\n","\n","pretrained_model_path = \"\" #@param{type:\"string\"}\n","h5_file_path = pretrained_model_path+'/'+Weights_choice+'_weights.h5'\n","\n","if not os.path.exists(h5_file_path) and Use_pretrained_model:\n"," print('WARNING pretrained model does not exist')\n"," Use_pretrained_model = False\n","\n","if Use_pretrained_model:\n"," with open(os.path.join(pretrained_model_path, 'Quality Control', 'training_evaluation.csv'),'r') as csvfile:\n"," csvRead = pd.read_csv(csvfile, sep=',')\n"," if \"learning rate\" in csvRead.columns: #Here we check that the learning rate column exist (compatibility with model trained un ZeroCostDL4Mic bellow 1.4):\n"," print(\"pretrained network learning rate found\")\n"," #find the last learning rate\n"," lastLearningRate = csvRead[\"learning rate\"].iloc[-1]\n"," #Find the learning rate corresponding to the lowest validation loss\n"," min_val_loss = csvRead[csvRead['val_loss'] == min(csvRead['val_loss'])]\n"," #print(min_val_loss)\n"," bestLearningRate = min_val_loss['learning rate'].iloc[-1]\n","\n"," if Weights_choice == \"last\":\n"," print('Last learning rate: '+str(lastLearningRate))\n"," learning_rate = lastLearningRate\n","\n"," if Weights_choice == \"best\":\n"," print('Learning rate of best validation loss: '+str(bestLearningRate))\n"," learning_rate = bestLearningRate\n","\n"," if not \"learning rate\" in csvRead.columns: #if the column does not exist, then initial learning rate is used instead\n"," print(bcolors.WARNING+'WARNING: The learning rate cannot be identified from the pretrained network. Default learning rate of '+str(bestLearningRate)+' will be used instead' + W)\n","else:\n"," print('No pre-trained models will be used.')\n","\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.1. Start training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches.\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Start training\n","\n","# ------------\n","\n","os.chdir('/content/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/keras-yolo2/full_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/keras-yolo2/inception_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/keras-yolo2/mobilenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/keras-yolo2/squeezenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","\n","\n","full_model_file_path = full_model_path+'/best_weights.h5'\n","os.chdir('/content/keras-yolo2/')\n","\n","#Change backend name\n","!sed -i 's@\\\"backend\\\":.*,@\\\"backend\\\": \\\"$backend\\\",@g' config.json\n","\n","#Change the name of the training folder\n","!sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$Training_Source/\\\",@g' config.json\n","\n","#Change annotation folder\n","!sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$Training_Source_annotations/\\\",@g' config.json\n","\n","#Change the name of the saved model\n","!sed -i 's@\\\"saved_weights_name\\\":.*,@\\\"saved_weights_name\\\": \\\"$full_model_file_path\\\",@g' config.json\n","\n","#Change warmup epochs for untrained model\n","!sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 3,@g' config.json\n","\n","#When defining a new model we should reset the pretrained model parameter\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"No_pretrained_weights\\\",@g' config.json\n","\n","!sed -i 's@\\\"nb_epochs\\\":.*,@\\\"nb_epochs\\\": $number_of_epochs,@g' config.json\n","\n","!sed -i 's@\\\"train_times\\\":.*,@\\\"train_times\\\": $train_times,@g' config.json\n","!sed -i 's@\\\"batch_size\\\":.*,@\\\"batch_size\\\": $batch_size,@g' config.json\n","!sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","!sed -i 's@\\\"object_scale\":.*,@\\\"object_scale\\\": $false_negative_penalty,@g' config.json\n","!sed -i 's@\\\"no_object_scale\":.*,@\\\"no_object_scale\\\": $false_positive_penalty,@g' config.json\n","!sed -i 's@\\\"coord_scale\\\":.*,@\\\"coord_scale\\\": $position_size_penalty,@g' config.json\n","!sed -i 's@\\\"class_scale\\\":.*,@\\\"class_scale\\\": $false_class_penalty,@g' config.json\n","\n","#Write the annotations to a csv file\n","df_anno.to_csv(full_model_path+'/annotations.csv', index=False)#header=False, sep=',')\n","\n","!sed -i 's@\\\"labels\\\":.*@\\\"labels\\\": $class_labels@g' config.json\n","\n","\n","#Generate anchors for the bounding 
boxes\n","os.chdir('/content/keras-yolo2')\n","output = sp.getoutput('python ./gen_anchors.py -c ./config.json')\n","\n","anchors_1 = output.find(\"[\")\n","anchors_2 = output.find(\"]\")\n","\n","config_anchors = output[anchors_1:anchors_2+1]\n","!sed -i 's@\\\"anchors\\\":.*,@\\\"anchors\\\": $config_anchors,@g' config.json\n","\n","\n","!sed -i 's@\\\"pretrained_weights\\\":.*,@\\\"pretrained_weights\\\": \\\"$h5_file_path\\\",@g' config.json\n","\n","\n","# !sed -i 's@\\\"anchors\\\":.*,@\\\"anchors\\\": $config_anchors,@g' config.json\n","\n","\n","if Use_pretrained_model:\n"," !sed -i 's@\\\"warmup_epochs\\\":.*,@\\\"warmup_epochs\\\": 0,@g' config.json\n"," !sed -i 's@\\\"learning_rate\\\":.*,@\\\"learning_rate\\\": $learning_rate,@g' config.json\n","\n","if Use_Data_augmentation:\n"," #Change the name of the training folder\n"," !sed -i 's@\\\"train_image_folder\\\":.*,@\\\"train_image_folder\\\": \\\"$augmented_training_source/\\\",@g' config.json\n","\n"," #Change annotation folder\n"," !sed -i 's@\\\"train_annot_folder\\\":.*,@\\\"train_annot_folder\\\": \\\"$augmented_training_source_annotation/\\\",@g' config.json\n","\n","\n","# ------------\n","\n","\n","\n","if os.path.exists(full_model_path+\"/Quality Control\"):\n"," shutil.rmtree(full_model_path+\"/Quality Control\")\n","os.makedirs(full_model_path+\"/Quality Control\")\n","\n","\n","start = time.time()\n","\n","os.chdir('/content/keras-yolo2')\n","train('config.json', full_model_path, percentage_validation)\n","\n","shutil.copyfile('/content/keras-yolo2/config.json',full_model_path+'/config.json')\n","\n","if os.path.exists('/content/keras-yolo2/best_map_weights.h5'):\n"," shutil.move('/content/keras-yolo2/best_map_weights.h5',full_model_path+'/best_map_weights.h5')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","if (Use_the_current_trained_model): \n"," QC_model_folder = full_model_path\n","\n","#print(os.path.join(model_path, model_name))\n","\n","QC_model_name = os.path.basename(QC_model_folder)\n","\n","if os.path.exists(QC_model_folder):\n"," print(\"The \"+QC_model_name+\" model will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path before proceeding further.')\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/keras-yolo2/config.json'):\n"," os.remove('/content/keras-yolo2/config.json')\n"," shutil.copyfile(QC_model_folder+'/config.json','/content/keras-yolo2/config.json')\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","os.chdir('/content/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/keras-yolo2/full_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/keras-yolo2/inception_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/keras-yolo2/mobilenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/keras-yolo2/squeezenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. 
In this case the training dataset has to be increased."]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show a plot of training errors vs. epoch number\n","import csv\n","from matplotlib import pyplot as plt\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","mAPDataFromCSV = []\n","with open(QC_model_folder+'/Quality Control/training_evaluation.csv','r') as csvfile:\n"," csvRead = csv.reader(csvfile, delimiter=',')\n"," next(csvRead)\n"," for row in csvRead:\n"," lossDataFromCSV.append(float(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n"," mAPDataFromCSV.append(float(row[2]))\n","epochNumber = range(len(lossDataFromCSV))\n","plt.figure(figsize=(20,15))\n","\n","plt.subplot(3,1,1)\n","plt.plot(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","\n","plt.subplot(3,1,2)\n","plt.semilogy(epochNumber,lossDataFromCSV, label='Training loss')\n","plt.semilogy(epochNumber,vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. epoch number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","#plt.savefig(os.path.dirname(QC_model_folder)+'/Quality Control/lossCurvePlots.png')\n","#plt.show()\n","\n","plt.subplot(3,1,3)\n","plt.plot(epochNumber,mAPDataFromCSV, label='mAP score')\n","plt.title('mean average precision (mAP) vs. epoch number (linear scale)')\n","plt.ylabel('mAP score')\n","plt.xlabel('Epoch number')\n","plt.legend()\n","plt.savefig(QC_model_folder+'/Quality Control/lossCurveAndmAPPlots.png',bbox_inches='tight', pad_inches=0)\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display an overlay of the input images ground-truth (solid lines) and predicted boxes (dashed lines). Additionally, the below cell will show the mAP value of the model on the QC data together with plots of the Precision-Recall curves for all the classes in the dataset. If you want to read in more detail about these scores, we recommend [this brief explanation](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173).\n","\n"," The images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" should contain images (e.g. as .jpg)and annotations (.xml files)!\n","\n","Since the training saves three different models, for the best validation loss (`best_weights`), best average precision (`best_mAP_weights`) and the model after the last epoch (`last_weights`), you should choose which ones you want to use for quality control or prediction. We recommend using `best_map_weights` because they should yield the best performance on the dataset. However, it can be worth testing how well `best_weights` perform too.\n","\n","**mAP score:** This refers to the mean average precision of the model on the given dataset. This value gives an indication how precise the predictions of the classes on this dataset are when compared to the ground-truth. 
Values closer to 1 indicate a good fit.\n","\n","**Precision:** This is the proportion of the correct classifications (true positives) in all the predictions made by the model.\n","\n","**Recall:** This is the proportion of the detected true positives in all the detectable data."]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Annotations_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#@markdown ##Choose which model you want to evaluate:\n","model_choice = \"best_map_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","file_suffix = os.path.splitext(os.listdir(Source_QC_folder)[0])[1]\n","\n","# Create a quality control/Prediction Folder\n","if os.path.exists(QC_model_folder+\"/Quality Control/Prediction\"):\n"," shutil.rmtree(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","os.makedirs(QC_model_folder+\"/Quality Control/Prediction\")\n","\n","#Delete old csv with box predictions if one exists\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists(Source_QC_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Source_QC_folder+'/.ipynb_checkpoints')\n","\n","os.chdir('/content/keras-yolo2')\n","\n","n_objects = []\n","for img in os.listdir(Source_QC_folder):\n"," full_image_path = Source_QC_folder+'/'+img\n"," print('----')\n"," print(img)\n"," n_obj = predict('config.json',QC_model_folder+'/'+model_choice+'.h5',full_image_path)\n"," n_objects.append(n_obj)\n"," K.clear_session()\n","\n","for img in os.listdir(Source_QC_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Source_QC_folder+'/'+img,QC_model_folder+\"/Quality Control/Prediction/\"+img)\n","\n","#Here, we open the config file to get the classes fro the GT labels\n","config_path = '/content/keras-yolo2/config.json'\n","with open(config_path) as config_buffer:\n"," config = json.load(config_buffer)\n","\n","#Make a csv file to read into imagej macro, to create custom bounding boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(QC_model_folder+'/Quality Control/predicted_bounding_boxes_for_custom_ROI_QC.csv')\n","\n","F1_scores, AP, recall, precision = _calc_avg_precisions(config,Source_QC_folder,Annotations_QC_folder+'/',QC_model_folder+'/'+model_choice+'.h5',0.3,0.3)\n","\n","\n","\n","with open(QC_model_folder+\"/Quality Control/QC_results.csv\", \"r\") as file:\n"," x 
= from_csv(file)\n"," \n","print(x)\n","\n","mAP_score = sum(AP.values())/len(AP)\n","\n","print('mAP score for QC dataset: '+str(mAP_score))\n","\n","for i in range(len(AP)):\n"," if AP[i]!=0:\n"," fig = plt.figure(figsize=(8,4))\n"," if len(recall[i]) == 1:\n"," new_recall = np.linspace(0,list(recall[i])[0],10)\n"," new_precision = list(precision[i])*10\n"," fig = plt.figure(figsize=(3,2))\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', bbox_inches='tight', pad_inches=0)\n"," plt.show()\n"," else:\n"," new_recall = list(recall[i])\n"," new_recall.append(new_recall[len(new_recall)-1])\n"," new_precision = list(precision[i])\n"," new_precision.append(0)\n"," plt.plot(new_recall,new_precision)\n"," plt.axis([min(new_recall),1,0,1.02])\n"," plt.xlabel('Recall',fontsize=14)\n"," plt.ylabel('Precision',fontsize=14)\n"," plt.title(config['model']['labels'][i]+', AP: '+str(round(AP[i],3)),fontsize=14)\n"," plt.fill_between(new_recall,new_precision,alpha=0.3)\n"," plt.savefig(QC_model_folder+'/Quality Control/P-R_curve_'+config['model']['labels'][i]+'.png', bbox_inches='tight', pad_inches=0)\n"," plt.show()\n"," else:\n"," print('No object of class '+config['model']['labels'][i]+' was detected. This will lower the mAP score. Consider adding an image containing this class to your QC dataset to see if the model can detect this class at all.')\n","\n","\n","# --------------------------------------------------------------\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","# This will display a randomly chosen dataset input and predicted output\n","\n","print('Below is an example input, prediction and ground truth annotation from your test dataset.')\n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","file_suffix = os.path.splitext(random_choice)[1]\n","\n","plt.figure(figsize=(30,15))\n","\n","### Display Raw input ###\n","\n","x = plt.imread(Source_QC_folder+\"/\"+random_choice)\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest', cmap='gray')\n","plt.title('Input', fontsize = 12)\n","\n","### Display Predicted annotation ###\n","\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Source_QC_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,2)\n"," plt.axis('off')\n"," plt.imshow(image, cmap='gray') # plot image\n"," plt.title('Prediction', fontsize=12)\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n","\n","\n","### Display GT Annotation ###\n","\n","df_anno_QC_gt = []\n","for fnm in os.listdir(Annotations_QC_folder): \n"," if not fnm.startswith('.'): ## do not include hidden folders/files\n"," tree = ET.parse(os.path.join(Annotations_QC_folder,fnm))\n"," row = extract_single_xml_file(tree)\n"," 
row[\"fileID\"] = os.path.splitext(fnm)[0]\n"," df_anno_QC_gt.append(row)\n","df_anno_QC_gt = pd.DataFrame(df_anno_QC_gt)\n","#maxNobj = np.max(df_anno_QC_gt[\"Nobj\"])\n","\n","for i in range(0,df_anno_QC_gt.shape[0]):\n"," if df_anno_QC_gt.iloc[i][\"fileID\"]+file_suffix == random_choice:\n"," row = df_anno_QC_gt.iloc[i]\n","\n","img = imageio.imread(Source_QC_folder+'/'+random_choice)\n","plt.subplot(1,3,3)\n","plt.axis('off')\n","plt.imshow(img, cmap='gray') # plot image\n","plt.title('Ground Truth annotations', fontsize=12)\n","\n","# for each object in the image, plot the bounding box\n","for iplot in range(row[\"Nobj\"]):\n"," plt_rectangle(plt,\n"," label = row[\"bbx_{}_name\".format(iplot)],\n"," x1=row[\"bbx_{}_xmin\".format(iplot)],\n"," y1=row[\"bbx_{}_ymin\".format(iplot)],\n"," x2=row[\"bbx_{}_xmax\".format(iplot)],\n"," y2=row[\"bbx_{}_ymax\".format(iplot)])#,\n"," #fontsize=8)\n","\n","### Show the plot ###\n","plt.savefig(QC_model_folder+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as restored image stacks (ImageJ-compatible TIFF images).\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`Prediction_model_path`:** This should be the folder that contains your model."]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","file_suffix = os.path.splitext(os.listdir(Data_folder)[0])[1]\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, provide the name of the model and path to model folder:\n","\n","Prediction_model_path = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Which model do you want to use?\n","model_choice = \"best_map_weights\" #@param[\"best_weights\",\"last_weights\",\"best_map_weights\"]\n","\n","#@markdown ###Which backend is the model using?\n","backend = \"Full Yolo\" #@param [\"Select Model\",\"Full Yolo\",\"Inception3\",\"SqueezeNet\",\"MobileNet\",\"Tiny Yolo\"]\n","\n","\n","os.chdir('/content/keras-yolo2')\n","if backend == \"Full Yolo\":\n"," if not os.path.exists('/content/keras-yolo2/full_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/full_yolo_backend.h5\n","elif backend == \"Inception3\":\n"," if not os.path.exists('/content/keras-yolo2/inception_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/inception_backend.h5\n","elif backend == \"MobileNet\":\n"," if not os.path.exists('/content/keras-yolo2/mobilenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/mobilenet_backend.h5\n","elif backend == \"SqueezeNet\":\n"," if not os.path.exists('/content/keras-yolo2/squeezenet_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/squeezenet_backend.h5\n","elif backend == \"Tiny Yolo\":\n"," if not os.path.exists('/content/keras-yolo2/tiny_yolo_backend.h5'):\n"," !wget /~https://github.com/rodrigo2019/keras_yolo2/releases/download/pre-trained-weights/tiny_yolo_backend.h5\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," Prediction_model_path = full_model_path\n","\n","if Use_the_current_trained_model == False:\n"," if os.path.exists('/content/keras-yolo2/config.json'):\n"," os.remove('/content/keras-yolo2/config.json')\n"," shutil.copyfile(Prediction_model_path+'/config.json','/content/keras-yolo2/config.json')\n","\n","if os.path.exists(Prediction_model_path+'/'+model_choice+'.h5'):\n"," print(\"The \"+os.path.basename(Prediction_model_path)+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Provide the code for performing predictions and saving them\n","print(\"Images will be saved into folder:\", Result_folder)\n","\n","\n","# ----- Predictions ------\n","\n","start = time.time()\n","\n","#Remove any files that might be from the prediction of QC examples.\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," os.remove('/content/predicted_bounding_boxes.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_new.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names.csv')\n","if os.path.exists('/content/predicted_bounding_boxes_names_new.csv'):\n"," os.remove('/content/predicted_bounding_boxes_names_new.csv')\n","\n","os.chdir('/content/keras-yolo2')\n","\n","if os.path.exists(Data_folder+'/.ipynb_checkpoints'):\n"," shutil.rmtree(Data_folder+'/.ipynb_checkpoints')\n","\n","n_objects = []\n","for img in os.listdir(Data_folder):\n"," full_image_path = Data_folder+'/'+img\n"," n_obj = predict('config.json',Prediction_model_path+'/'+model_choice+'.h5',full_image_path)#,Result_folder)\n"," n_objects.append(n_obj)\n"," K.clear_session()\n","for img in os.listdir(Data_folder):\n"," if img.endswith('detected'+file_suffix):\n"," shutil.move(Data_folder+'/'+img,Result_folder+'/'+img)\n","\n","if os.path.exists('/content/predicted_bounding_boxes.csv'):\n"," #shutil.move('/content/predicted_bounding_boxes.csv',Result_folder+'/predicted_bounding_boxes.csv')\n"," print('Bounding box labels and coordinates saved to '+ Result_folder)\n","else:\n"," print('For some reason the bounding box labels and coordinates were not saved. Check that your predictions look as expected.')\n","\n","#Make a csv file to read into imagej macro, to create custom bounding boxes\n","header = ['filename']+['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class']*max(n_objects)\n","with open('/content/predicted_bounding_boxes.csv', newline='') as inFile, open('/content/predicted_bounding_boxes_new.csv', 'w', newline='') as outfile:\n"," r = csv.reader(inFile)\n"," w = csv.writer(outfile)\n"," next(r, None) # skip the first row from the reader, the old header\n"," # write new header\n"," w.writerow(header)\n"," # copy the rest\n"," for row in r:\n"," w.writerow(row)\n","\n","df_bbox=pd.read_csv('/content/predicted_bounding_boxes_new.csv')\n","df_bbox=df_bbox.transpose()\n","new_header = df_bbox.iloc[0] #grab the first row for the header\n","df_bbox = df_bbox[1:] #take the data less the header row\n","df_bbox.columns = new_header #set the header row as the df header\n","df_bbox.sort_values(by='filename',axis=1,inplace=True)\n","df_bbox.to_csv(Result_folder+'/predicted_bounding_boxes_for_custom_ROI.csv')\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"tP1isF0PO4C1"},"source":["## **6.2. 
Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"ypLeYWnzO6tv","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","print(random_choice)\n","x = imread(Data_folder+\"/\"+random_choice)\n","\n","os.chdir(Result_folder)\n","y = imread(Result_folder+\"/\"+os.path.splitext(random_choice)[0]+'_detected'+file_suffix)\n","\n","plt.figure(figsize=(20,8))\n","\n","plt.subplot(1,3,1)\n","plt.axis('off')\n","plt.imshow(x, interpolation='nearest', cmap='gray')\n","plt.title('Input')\n","\n","plt.subplot(1,3,2)\n","plt.axis('off')\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Predicted output');\n","\n","add_header('/content/predicted_bounding_boxes_names.csv','/content/predicted_bounding_boxes_names_new.csv')\n","\n","#We need to edit this predicted_bounding_boxes_new.csv file slightly to display the bounding boxes\n","df_bbox2 = pd.read_csv('/content/predicted_bounding_boxes_names_new.csv')\n","for img in range(0,df_bbox2.shape[0]):\n"," df_bbox2.iloc[img]\n"," row = pd.DataFrame(df_bbox2.iloc[img])\n"," if row[img][0] == random_choice:\n"," row = row.dropna()\n"," image = imageio.imread(Data_folder+'/'+row[img][0])\n"," #plt.figure(figsize=(12,12))\n"," plt.subplot(1,3,3)\n"," plt.axis('off')\n"," plt.title('Alternative Display of Prediction')\n"," plt.imshow(image, cmap='gray') # plot image\n","\n"," for i in range(1,int(len(row)-1),6):\n"," plt_rectangle(plt,\n"," label = row[img][i+5],\n"," x1=row[img][i],#.format(iplot)],\n"," y1=row[img][i+1],\n"," x2=row[img][i+2],\n"," y2=row[img][i+3])#,\n"," #fontsize=8)\n"," #plt.margins(0,0)\n"," #plt.subplots_adjust(left=0., right=1., top=1., bottom=0.)\n"," #plt.gca().xaxis.set_major_locator(plt.NullLocator())\n"," #plt.gca().yaxis.set_major_locator(plt.NullLocator())\n"," plt.savefig('/content/detected_cells.png',bbox_inches='tight',transparent=True,pad_inches=0)\n","plt.show() ## show the plot\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"hbfhHocc9FNq"},"source":["# **7. 
Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version now includes an automatic restart allowing to set the h5py library to v2.10.\n","* The section 1 and 2 are now swapped for better export of *requirements.txt*.\n","* The YOLOv2 repository is no longer downloaded to the user's google drive but is saved to the content folder, consistent with other ZeroCostDL4Mic notebooks\n","\n","* This version also now includes built-in version check and the version log that you're reading now."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using YOLOv2!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/fnet_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/fnet_2D_ZeroCostDL4Mic.ipynb new file mode 100644 index 00000000..4e19629b --- /dev/null +++ b/Colab_notebooks/fnet_2D_ZeroCostDL4Mic.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fnet_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1SccQbvLCKwh2BgUTvdHPd_aWW9GRxQH4","timestamp":1622636347375},{"file_id":"1SisekHpRSJ0QKHvDePqFe09lkklVytwI","timestamp":1622479180098},{"file_id":"12UsRdIQbcWQjYewI2wrcwIWfVxc6hOfc","timestamp":1620660071757},{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611063104553},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developped to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). 
Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: /~https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n","**Data Format**\n","\n"," **The data used to train fnet (2D) must be 2D images in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on user data, upload the data in the following format to your google drive. 
To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["#**1. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"cell_type":"markdown","metadata":{"id":"6d8soOdxtp3z"},"source":["## **1.1. Install key dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"btZhzA1oswmW","cellView":"form"},"source":["#@markdown ##Install fnet and dependencies\n","\n","!pip install fpdf\n","!pip install -U scipy==1.2.0\n","!pip install tifffile==2019.7.26\n","!pip install matplotlib==2.2.3\n","\n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HjOV3wm_tvOJ"},"source":["## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","** Ignore the following message error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
"]},{"cell_type":"markdown","metadata":{"id":"ZNJ4RSnZtxpM"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'fnet (2D)'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","# !pip install fpdf\n","# !pip install -U scipy==1.2.0\n","# !pip install tifffile==2019.7.26\n","# !pip install matplotlib==2.2.3\n","#@markdown ##Load key dependencies\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import numpy as np\n","import shutil\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","from datetime import datetime\n","from astropy.visualization import simple_norm\n","import time\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","#clone fnet from github to colab\n","!git clone -b release_1 --single-branch /~https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n","\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","import matplotlib as mpl\n","\n","\n","def replace(file_path, pattern, subst):\n"," \"\"\"Function replaces a pattern in a .py file with a new pattern.\"\"\"\n"," \n"," \"\"\"Parameters:\n"," -file_path (string): path to the file to be changed.\n"," -pattern (string): pattern to be replaced. Make sure this is as unique as possible.\n"," -subst (string): new pattern. 
\"\"\"\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def add_insert(filepath,line_number,insertion,append):\n"," \"\"\"Function which inserts the a line into a document.\"\"\"\n"," \n"," \"\"\"Parameters:\n"," -filepath (string): path to the file which needs to be edited.\n"," -line (integer): Where to insert the new line. In the file, this line is ideally an empty one.\n"," -insertion (string): The line to be inserted. If it already exists it will not be added again.\n"," -append (string): If anything additional needs to be appended to the line, use this. Otherwise, leave as \"\" \"\"\"\n"," \n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," if append != \"\":\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","\n","def convert_to_script_compatible_path(original_path):\n"," \"\"\"Function which converts 'original_path' into a compatible format 'new_full_path' with the fnet .sh files \"\"\"\n"," new_full_path = \"\"\n"," for s in original_path:\n"," if s=='/':\n"," new_full_path += '\\/'\n"," else:\n"," new_full_path += s\n","\n"," return new_full_path\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / 
np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","#Change the default dataset type in the training module to .tif\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","\n","#2D \n","\n","replace(\"/content/pytorch_fnet/train_model.py\",\"default=[32, 64, 64]\",\"default=[128, 128]\")\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--nn_module', default='fnet_nn_3d'\",\"'--nn_module', default='fnet_nn_2d'\")\n","\n","replace(\"/content/pytorch_fnet/train_model.py\",\", default_resizer_str]\",\"]\")\n","#replace(\"/content/pytorch_fnet/predict.py\",\", default_resizer_str]\",\"]\")\n","\n","\n","\n","print(\"-------------------\")\n","print(\"Libraries installed\")\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free Prediction (fnet)'\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n"," #if Use_pretrained_model:\n"," # text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). 
Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), pytorch (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'. The model was trained from the pretrained model: '+pretrained_model_path+'.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<table>
 <tr><th>Parameter</th><th>Value</th></tr>
 <tr><td>percentage_validation</td><td>{0}</td></tr>
 <tr><td>steps</td><td>{1}</td></tr>
 <tr><td>batch_size</td><td>{2}</td></tr>
</table>
\n"," \"\"\".format(percentage_validation,steps,batch_size)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Fnet.png').shape\n"," pdf.image('/content/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," if trained:\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," else:\n"," pdf.output('/content/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," \n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free prediction (fnet)'\n"," #model_name = os.path.basename(QC_model_folder)\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," if os.path.exists(full_QC_model_path+'/QualityControl/lossCurvePlots.png'):\n"," exp_size = io.imread(full_QC_model_path+'/QualityControl/lossCurvePlots.png').shape\n"," pdf.image(full_QC_model_path+'/QualityControl/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', 
size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/QualityControl/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/QualityControl/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/QualityControl/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," NRMSE_PvsGT = header[2]\n"," PSNR_PvsGT = header[3]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,NRMSE_PvsGT,PSNR_PvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," NRMSE_PvsGT = row[2]\n"," PSNR_PvsGT = row[3]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(PSNR_PvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td></tr>
</table>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/QualityControl/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hT5ZbFGyfae0"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
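\n","\n","Once the drive is mounted and the session is active, your files are reachable under the mount point: for example, a (hypothetical) folder named fnet_training stored at the top level of your Drive would appear as /content/gdrive/My Drive/fnet_training in the Files tab, and paths of this form are what section 3 expects.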
"]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n"," **Training Parameters**\n","\n"," **`percentage validation`** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. **Default: 64**\n","\n","**Note 2: If you only need to retrain your model after a time-out, skip this cell and go straight to section 4.2. Just make sure your training datasets are still in their original folders.**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Datasets\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","### Choosing and editing the path names ###\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","full_model_path = model_path+'/'+model_name\n","\n","new_full_model_path = convert_to_script_compatible_path(full_model_path)\n","new_full_model_path_csv = new_full_model_path+'\\/'+model_name+'\\.csv'\n","new_full_model_path_val_csv = new_full_model_path+'\\/'+model_name+'_val\\.csv'\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","source = os.listdir(Training_source)\n","target = os.listdir(Training_target)\n","number_of_images = len(source[:-round(len(source)*(percentage_validation/100))])\n","\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","\n","### Edit the train.sh script file and train.py file ###\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","with open('/content/pytorch_fnet/scripts/train_model.sh','r') as scriptfile:\n"," lines = scriptfile.readlines()\n"," if 'PATH_DATASET_VAL_CSV' not in lines:\n"," insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",10,insert,\"\")\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",22,'\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}',\"\")\n"," \n","#Clear the White space from train.sh\n","with open('/content/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/pytorch_fnet/scripts/train_model_temp.sh','/content/pytorch_fnet/scripts/train_model.sh')\n","\n","# #Here we define the random set of training files to be used for validation\n","val_files = source[-round(len(source)*(percentage_validation/100)):]\n","source_files = source[:-round(len(source)*(percentage_validation/100))]\n","\n","# #Finally, we create a validation csv file to construct the validation dataset\n","# with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n","# writer = csv.writer(file)\n","# writer.writerow([\"path_signal\",\"path_target\"])\n","# for i in range(0,len(val_files)):\n","# writer.writerow([Training_source+'/'+val_files[i],Training_target+'/'+val_files[i]])\n","\n","\n","# with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n","# writer = csv.writer(file)\n","# writer.writerow([\"path_signal\",\"path_target\"])\n","# for i in range(0,len(source_files)):\n","# writer.writerow([Training_source+\"/\"+source_files[i],Training_target+\"/\"+source_files[i]])\n","\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 64#@param {type:\"number\"}\n","# learning_rate = 0.0001\n","# patch_size = 256\n","\n","number_of_images = len(source_files)\n","\n","#3. 
Insert the above values into train_model.sh\n","!if ! grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','PATH_DATASET_ALL_CSV','#PATH_DATASET_ALL_CSV')\n","\n","#No Augmentation by default\n","Use_Data_augmentation = False\n","\n","#Load one randomly chosen training source file\n","random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","os.chdir(Training_target)\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 
20-30 images).\n","\n","**Note 2: If you intend to use the retraining option at a later point, save the dataset in a folder in your Google Drive.**"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["from skimage import io\n","import numpy as np\n","\n","Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target', flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(0,1))\n"," source_img_180 = np.rot90(source_img_90,axes=(0,1))\n"," source_img_270 = np.rot90(source_img_180,axes=(0,1))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(0,1))\n"," target_img_180 = np.rot90(target_img_90,axes=(0,1))\n"," target_img_270 = np.rot90(target_img_180,axes=(0,1))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," 
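# save the remaining flipped variants: the 270-degree flipped source below, followed by the flipped copies of the target\n","    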
io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target'):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n"," \n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n","\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the target folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," \n"," # aug_source = os.listdir(Saving_path+'/augmented_source')\n"," # aug_val_files = aug_source[-round(len(aug_source)*(percentage_validation/100)):]\n"," # aug_source_files = aug_source[:-round(len(aug_source)*(percentage_validation/100))]\n","\n"," # #Finally, we create a validation csv file to construct the validation dataset\n"," # with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," # writer = csv.writer(file)\n"," # writer.writerow([\"path_signal\",\"path_target\"])\n"," # for i in range(0,len(aug_val_files)):\n"," # writer.writerow([Saving_path+'/augmented_source/'+aug_val_files[i],Saving_path+\"/augmented_target/\"+aug_val_files[i]])\n","\n"," # with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," # writer = csv.writer(file)\n"," # writer.writerow([\"path_signal\",\"path_target\"])\n"," # for i in range(0,len(aug_source_files)):\n"," # writer.writerow([Saving_path+'/augmented_source/'+aug_source_files[i],Saving_path+'/augmented_target/'+aug_source_files[i]])\n","\n"," #Here, we ensure that there aren't too many images in the buffer.\n"," #The 
best value will depend on the size of the images and the assigned GPU.\n"," #If too many images are loaded to the buffer the notebook will terminate the training as the RAM limit will be exceeded.\n"," aug_source = os.listdir(Saving_path+'/augmented_source') #list the augmented source images used to set the buffer size\n"," if len(aug_source)>100:\n"," number_of_images = 100\n"," else:\n"," number_of_images = len(aug_source)\n","\n"," os.chdir(\"/content/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Nyf9ndiS7sL9"},"source":["#**4. Train the network**\n","---\n","Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training on an already pre-trained model, choose this section.\n","\n"," **Carefully read the options before starting training.**"]},{"cell_type":"markdown","metadata":{"id":"P9OJ0nlI71Rc"},"source":["##**4.1. Start Training**\n","---\n","\n","####Play the cell below to start training. \n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model, save it before playing the cell below or give your model a different name (section 3).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is, however, wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"cellView":"form","id":"l1hL7H4hVizI"},"source":["#@markdown ##Create the dataset files for training\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," print(bcolors.WARNING +\"!! 
Existing model \"+model_name+\" was found and overwritten!!\")\n","os.mkdir(model_path+'/'+model_name)\n","\n","#os.chdir(model_path)\n","# source = os.listdir(Training_source)\n","# target = os.listdir(Training_target)\n","\n","if Use_Data_augmentation == True:\n","\n"," aug_source = os.listdir(Saving_path+'/augmented_source')\n"," aug_val_files = aug_source[-round(len(aug_source)*(percentage_validation/100)):]\n"," aug_source_files = aug_source[:-round(len(aug_source)*(percentage_validation/100))]\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(aug_val_files)):\n"," writer.writerow([Saving_path+'/augmented_source/'+aug_val_files[i],Saving_path+\"/augmented_target/\"+aug_val_files[i]])\n","\n"," with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(aug_source_files)):\n"," writer.writerow([Saving_path+'/augmented_source/'+aug_source_files[i],Saving_path+'/augmented_target/'+aug_source_files[i]])\n","\n","else:\n"," #Here we define the random set of training files to be used for validation\n"," val_files = source[-round(len(source)*(percentage_validation/100)):]\n"," source_files = source[:-round(len(source)*(percentage_validation/100))]\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_files)):\n"," writer.writerow([Training_source+'/'+val_files[i],Training_target+'/'+val_files[i]])\n","\n","\n"," with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source_files)):\n"," writer.writerow([Training_source+\"/\"+source_files[i],Training_target+\"/\"+source_files[i]])"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"X8YHeSGr76je","cellView":"form"},"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. 
If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n","os.chdir(\"/content/pytorch_fnet/scripts\")\n","number_of_images = 100#@param{type:\"number\"}\n","!chmod u+x train_model.sh\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"7Ofm-71T8ABX","cellView":"form"},"source":["#@markdown ##Start training\n","\n","pdf_export(augmentation = Use_Data_augmentation)\n","start = time.time()\n","\n","os.chdir('/content')\n","add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n","\n","### TRAIN THE MODEL ###\n","\n","print('Let''s start the training!')\n","#Here we start the training\n","!/content/pytorch_fnet/scripts/train_model.sh $model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bOdyjxWV8IrO"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out you can ignore this section.**"]},{"cell_type":"markdown","metadata":{"id":"-JxxMmVr8Tw-"},"source":["##**4.2. Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on. Make sure your training datasets are in the same location as when you originally trained the model.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","metadata":{"id":"iDIgosht8U7F","cellView":"form"},"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. 
This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we repeat steps already used above in case the notebook needs to be restarted for this cell.\n","#We need to add a new line to the train.sh file\n","with open(\"/content/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","with open('/content/pytorch_fnet/scripts/train_model.sh','r') as scriptfile:\n"," lines = scriptfile.readlines()\n"," if 'PATH_DATASET_VAL_CSV' not in lines:\n"," insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",10,insert,\"\")\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",22,'\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}',\"\")\n","\n","\n","#Clear the White space from train.sh\n","with open('/content/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/pytorch_fnet/scripts/train_model_temp.sh','/content/pytorch_fnet/scripts/train_model.sh')\n","\n","\n","#Change checkpoints\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","### Choosing and editing the path names ###\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","\n","full_model_path = Pretrained_model_path+'/'+Pretrained_model_name\n","\n","new_full_model_path = convert_to_script_compatible_path(full_model_path)\n","new_full_model_path_csv = new_full_model_path+'\\/'+Pretrained_model_name+'\\.csv'\n","new_full_model_path_val_csv = new_full_model_path+'\\/'+Pretrained_model_name+'_val\\.csv'\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#We get the example data and the number of images from the csv path file#\n","\n","with open(full_model_path+'/'+Pretrained_model_name+'.csv') as csvfile:\n"," csvreader = csv.reader(csvfile)\n"," header = next(csvreader)\n"," number_of_images = 0\n"," for line in csvreader:\n"," ExampleSource = line[0]\n"," ExampleTarget = line[1]\n"," number_of_images += 1\n","\n","with open(full_model_path+'/'+Pretrained_model_name+'_val.csv') as csvfile:\n"," csvreader = csv.reader(csvfile)\n"," header = next(csvreader)\n"," number_of_val_images = 0\n"," for line in csvreader:\n"," number_of_val_images += 1\n","\n","#Batch Size\n","\n","batch_size = 64 #@param {type:\"number\"}\n","\n","# Editing the train.sh script file #\n","\n","os.chdir('/content/pytorch_fnet/scripts')\n","\n","#Change the train_model.sh file to include chosen dataset\n","!chmod u+x ./train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" 
train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n","\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','PATH_DATASET_ALL_CSV','#PATH_DATASET_ALL_CSV')\n","\n","# Find the number of steps to add and then add #\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 2000#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","os.chdir('/content/pytorch_fnet/scripts')\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh\n","\n","# Display example data #\n","\n","#Load one randomly chosen training source file\n","#random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(ExampleSource)\n","\n","#Find image Z dimension and select the mid-plane\n","# Image_Z = x.shape[0]\n","# mid_plane = int(Image_Z / 2)+1\n","\n","#os.chdir(Training_target)\n","y = io.imread(ExampleTarget)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, norm=simple_norm(x, percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y, norm=simple_norm(y, percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"5IXdFqhM8gO2","cellView":"form"},"source":["Use_Data_augmentation = False \n","start = time.time()\n","\n","#@markdown ##4.2. 
Start re-training model\n","# !pip install tifffile==2019.7.26\n","\n","os.chdir('/content/pytorch_fnet/fnet')\n","\n","add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n","\n","#Here we retrain the model on the chosen dataset.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/train_model.sh\n","!./scripts/train_model.sh $Pretrained_model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Here, we redefine the variable names for the pdf export\n","percentage_validation = round((number_of_val_images/(number_of_images+number_of_val_images))*100)\n","steps = new_steps\n","model_name = Pretrained_model_name\n","model_path = Pretrained_model_path\n","Training_source = os.path.dirname(ExampleSource)\n","Training_target = os.path.dirname(ExampleTarget)\n","#Create a pdf document with training summary\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and its target.\n","\n","During training, both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the number of training `steps` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again and the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words, the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/QualityControl/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps, as well as calculate total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\"!\n","\n","**1. 
The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore, for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric for each pixel, considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized prediction and target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score, the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that the model needs to predict or the type of GPU allocated to your session. 
To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","### Choosing and editing the path names ###\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","qc_images = len(os.listdir(Source_QC_folder))\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/QualityControl/Predictions\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/QualityControl/Predictions\")\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","new_full_QC_model_path = convert_to_script_compatible_path(full_QC_model_path)\n","new_full_QC_model_path_dataset = new_full_QC_model_path+'\\${DATASET}'\n","new_full_QC_model_path_csv = new_full_QC_model_path+'\\/QualityControl\\/qc\\${TEST_OR_TRAIN}\\.csv'# Get the name of the folder the test data is in\n","\n","\n","### Editing the predict.sh script file ###\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/pytorch_fnet/')\n","# !chmod u+x ./scripts/predict.sh\n","# !sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" ./scripts/predict.sh\n","!chmod u+x ./scripts/predict_2d.sh\n","!sed -i \"1,21!d\" ./scripts/predict_2d.sh\n","\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" ./scripts/predict_2d.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","#!sed -i \"s/in test.*/in test/g\" ./scripts/predict.sh\n","\n","# !if ! 
grep class_dataset ./scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n","# !if grep CziDataset ./scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' ./scripts/predict.sh; fi \n","\n","!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_QC_model_path/g\" ./scripts/predict_2d.sh\n","!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_QC_model_path_csv\\ \\\\\\/g\" ./scripts/predict_2d.sh\n","!sed -i \"s/path_save_dir.*/path_save_dir $new_full_QC_model_path\\/QualityControl\\/Predictions\\ \\\\\\/g\" ./scripts/predict_2d.sh\n","!sed -i \"s/N_IMAGES=.*/N_IMAGES=$qc_images/g\" ./scripts/predict_2d.sh\n","\n","### Create a path csv file for prediction (QC)###\n","\n","#Here we create a qctest.csv to locate the files used for QC\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","\n","with open(full_QC_model_path+'/QualityControl/qctest.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," writer.writerow([Source_QC_folder+'/'+test_signal[i],Target_QC_folder+'/'+test_signal[i]])\n","\n","### RUN THE PREDICTION ###\n","!./scripts/predict_2d.sh $Predictions_name 0\n","\n","### Save the results ###\n","QC_results_files = os.listdir(full_QC_model_path+'/QualityControl/Predictions')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," if os.path.isdir(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]):\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/prediction_'+QC_model_name+'.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction/'+'Predicted_'+test_signal[i])\n"," if os.path.exists(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/signal.tiff'):\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Signal/'+test_signal[i])\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Target/'+test_signal[i])\n"," else:\n"," shutil.copyfile(Source_QC_folder+'/'+test_signal[i],QC_model_path+'/'+QC_model_name+'/QualityControl/Signal/'+test_signal[i])\n"," shutil.copyfile(Target_QC_folder+'/'+test_target[i],QC_model_path+'/'+QC_model_name+'/QualityControl/Target/'+test_signal[i])\n","\n","shutil.rmtree(full_QC_model_path+'/QualityControl/Predictions')\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC 
folder\n","\n","#Finding the middle slice\n","img = io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","# n_slices = img.shape[0]\n","# z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/QualityControl/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Prediction v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," if len(test_GT_stack.shape) > 3:\n"," test_GT_stack = test_GT_stack.squeeze()\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," #test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," #n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((test_GT_stack.shape[0], test_GT_stack.shape[1]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((test_GT_stack.shape[0], test_GT_stack.shape[1]))\n","\n"," #for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack, test_prediction_stack, normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, np.squeeze(test_prediction_norm), data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,np.squeeze(test_prediction_norm),data_range=1.0)\n","\n","\n"," writer.writerow([thisFile,str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," 
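# slice_number_list stays unused here: SSIM, NRMSE and PSNR are computed on the whole 2D image rather than per Z slice\n"," 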
#slice_number_list.append(z)\n"," mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," # if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","# if len(img_GT.shape) > 3:\n","# img_GT = img_GT.squeeze()\n","plt.imshow(img_GT)\n","plt.title('Target')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source,aspect='equal',cmap=cmap)\n","plt.title('Source')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction)\n","plt.title('Prediction')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: 
Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(np.squeeze(img_RSE_GTvsPrediction), cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/QualityControl/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. 
To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","data_files = len(os.listdir(Data_folder))\n","\n","if os.path.exists(Results_folder+\"/Predictions\"):\n"," shutil.rmtree(Results_folder+\"/Predictions\")\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","### Choosing and editing the path names ###\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","if Use_the_current_trained_model:\n"," Prediction_model_folder = model_path+'/'+model_name\n","\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","Prediction_model_name_x = Prediction_model_name+\"}\"\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Convert the path variables into a compatible format with the script files #\n","# Prediction path conversion\n","new_full_Prediction_model_path = convert_to_script_compatible_path(full_Prediction_model_path)\n","new_full_Prediction_model_path_csv = new_full_Prediction_model_path+'\\${TEST_OR_TRAIN}\\.csv'# Get the name of the folder the test data is in\n","\n","# Result path conversion\n","new_Results_folder_path = convert_to_script_compatible_path(Results_folder)\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/predict_2d.sh\n","\n","### Editing the predict.sh script file ###\n","\n","# Make sure the dataset type is set to .tif (debug note: could be changed at install in predict.py file?)\n","# !if ! grep class_dataset ./scripts/predict_2d.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n","# !if grep CziDataset ./scripts/predict_2d.sh;then sed -i 's/CziDataset/TiffDataset/' /content/pytorch_fnet/scripts/predict.sh; fi \n","\n","# We allow the maximum number of images to be processed to be higher, i.e. 
1000.\n","!sed -i \"s/N_IMAGES=.*/N_IMAGES=$data_files/g\" ./scripts/predict_2d.sh\n","!sed -i \"s/1:-.*/1:-$Prediction_model_name_x/g\" ./scripts/predict_2d.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" ./scripts/predict.sh\n","!sed -i \"1,21!d\" ./scripts/predict_2d.sh\n","\n","#We change the directories in the predict.sh file to our needed paths\n","!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_Prediction_model_path/g\" ./scripts/predict_2d.sh\n","!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_Prediction_model_path_csv\\ \\\\\\/g\" ./scripts/predict_2d.sh\n","!sed -i \"s/path_save_dir.*/path_save_dir $new_Results_folder_path\\/Predictions\\ \\\\\\/g\" ./scripts/predict_2d.sh\n","\n","# Changing the GPU ID seems to help reduce errors\n","# replace('/content/pytorch_fnet/scripts/predict.sh','${GPU_IDS}','0')\n","\n","# We get rid of the options of saving signals and targets. Here, we just want predictions.\n","insert_1 = ' --no_signal \\\\\\n'\n","insert_2 = ' --no_target \\\\\\n'\n","add_insert(\"/content/pytorch_fnet/scripts/predict_2d.sh\",14,insert_1,\"\")\n","add_insert(\"/content/pytorch_fnet/scripts/predict_2d.sh\",14,insert_2,\"\")\n","\n","### Create the path csv file for prediction ###\n","\n","#Here we create a new test.csv with the paths to the dataset we want to predict on.\n","test_signal = os.listdir(Data_folder)\n","with open(full_Prediction_model_path+'/test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," writer.writerow([Data_folder+\"/\"+test_signal[i],Data_folder+\"/\"+test_signal[i]])\n","\n","### WE RUN THE PREDICTION ###\n","start = time.time()\n","!/content/pytorch_fnet/scripts/predict_2d.sh $Prediction_model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Rename the results appropriately\n","Results = os.listdir(Results_folder+'/Predictions')\n","for i in Results:\n"," if os.path.isdir(Results_folder+'/Predictions/'+i):\n"," shutil.copyfile(Results_folder+'/Predictions/'+i+'/'+os.listdir(Results_folder+'/Predictions/'+i)[0],Results_folder+'/Predictions/'+'predicted_'+test_signal[int(i)])\n"," \n","for i in Results:\n"," if os.path.isdir(Results_folder+'/Predictions/'+i):\n"," shutil.rmtree(Results_folder+'/Predictions/'+i)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bFtArIjs9tS9"},"source":["##**6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. 
Select the slice of the slice you want to visualize."]},{"cell_type":"code","metadata":{"id":"66-af3rO9vM4","cellView":"form"},"source":["#@markdown ###Select the image would you like to view below\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," os.chdir(Results_folder)\n","\n","#source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Predictions/predicted_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n","#Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image, cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image, cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"89tlSWBC940z"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"i5zAd43crdN_"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","* This notebook is new as of ZeroCostDL4Mic version 1.13. It includes the changes made across all the notebooks for this release.\n","* This section will include any future changes in following releases.\n"]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using fnet!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/fnet_3D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/fnet_3D_ZeroCostDL4Mic.ipynb new file mode 100644 index 00000000..9c58365e --- /dev/null +++ b/Colab_notebooks/fnet_3D_ZeroCostDL4Mic.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fnet_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1SisekHpRSJ0QKHvDePqFe09lkklVytwI","timestamp":1622728423435},{"file_id":"12UsRdIQbcWQjYewI2wrcwIWfVxc6hOfc","timestamp":1620660071757},{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611063104553},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":["IkSguVy8Xv83","jWAz2i7RdxUV","gKDLkLWUd-YX","UvSlTaH14s3t"],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developped to infer the distribution 
of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: /~https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. 
Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on user data, upload the data in the following format to your google drive. To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["#**1. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"cell_type":"markdown","metadata":{"id":"GgmEMSOUybyu"},"source":["## **1.1. 
Install key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"cellView":"form","id":"bGu_k66ZxoJW"},"source":["#@markdown ##Install fnet and dependencies\n","!pip install fpdf\n","#clone fnet from github to colab\n","!git clone -b release_1 --single-branch /~https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","!pip install tifffile==2019.7.26\n","# !pip install --no-cache-dir tifffile==2019.7.26 \n","#Force session restart\n","exit(0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_j2XyI76yhtT"},"source":["## **1.2. Restart your runtime**\n","---\n","\n","\n","\n","**Ignore the following error message. Your Runtime has automatically restarted. This is normal.**\n","\n","\"\"
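Once the runtime has restarted, an optional way to confirm that the packages pinned in section 1.1 are now active, before loading the remaining dependencies in section 1.3, is to print their versions. This is a small sketch added for illustration and is not part of the original install cell; the expected version numbers are simply the pip pins used above.

```python
# Optional sanity check (illustrative sketch): confirm the pinned versions are active
# after the forced runtime restart.
import importlib

# Expected versions are taken from the pip pins in the install cell of section 1.1.
expected = {"scipy": "1.2.0", "matplotlib": "2.2.3", "tifffile": "2019.7.26"}

for package, wanted in expected.items():
    module = importlib.import_module(package)
    found = getattr(module, "__version__", "unknown")
    status = "OK" if found == wanted else "MISMATCH"
    print(f"{package}: installed {found}, expected {wanted} -> {status}")
```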
\n"]},{"cell_type":"markdown","metadata":{"id":"hKXc0D11y6q8"},"source":["## **1.3. Load key dependencies**\n","---\n"," "]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'fnet (3D)'\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Load key dependencies\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import numpy as np\n","import shutil\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","from datetime import datetime\n","from astropy.visualization import simple_norm\n","import time\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","import matplotlib as mpl\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","def replace(file_path, pattern, subst):\n"," \"\"\"Function replaces a pattern in a .py file with a new pattern.\"\"\"\n"," \n"," \"\"\"Parameters:\n"," -file_path (string): path to the file to be changed.\n"," -pattern (string): pattern to be replaced. Make sure this is as unique as possible.\n"," -subst (string): new pattern. 
\"\"\"\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def add_insert(filepath,line_number,insertion,append):\n"," \"\"\"Function which inserts the a line into a document.\"\"\"\n"," \n"," \"\"\"Parameters:\n"," -filepath (string): path to the file which needs to be edited.\n"," -line (integer): Where to insert the new line. In the file, this line is ideally an empty one.\n"," -insertion (string): The line to be inserted. If it already exists it will not be added again.\n"," -append (string): If anything additional needs to be appended to the line, use this. Otherwise, leave as \"\" \"\"\"\n"," \n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," if append != \"\":\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","\n","def convert_to_script_compatible_path(original_path):\n"," \"\"\"Function which converts 'original_path' into a compatible format 'new_full_path' with the fnet .sh files \"\"\"\n"," new_full_path = \"\"\n"," for s in original_path:\n"," if s=='/':\n"," new_full_path += '\\/'\n"," else:\n"," new_full_path += s\n","\n"," return new_full_path\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / 
np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","#Change the default dataset type in the training module to .tif\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","print(\"-------------------\")\n","print(\"Libraries installed\")\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","\n","# Check if this is the latest version of the notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free Prediction (fnet)'\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","\n"," shape = 
io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n"," #if Use_pretrained_model:\n"," # text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), pytorch (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'. The model was trained from the pretrained model: '+pretrained_model_path+'.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
<tr><th>Parameter</th><th>Value</th></tr>
<tr><td>percentage_validation</td><td>{0}</td></tr>
<tr><td>steps</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
\n"," \"\"\".format(percentage_validation,steps,batch_size)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Fnet.png').shape\n"," pdf.image('/content/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," if trained:\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," else:\n"," pdf.output('/content/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," \n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free prediction (fnet)'\n"," #model_name = os.path.basename(QC_model_folder)\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," if os.path.exists(full_QC_model_path+'/QualityControl/lossCurvePlots.png'):\n"," exp_size = io.imread(full_QC_model_path+'/QualityControl/lossCurvePlots.png').shape\n"," pdf.image(full_QC_model_path+'/QualityControl/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', 
size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/QualityControl/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/QualityControl/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/QualityControl/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," PSNR_PvsGT = header[4]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,NRMSE_PvsGT,PSNR_PvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," PSNR_PvsGT = row[4]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(PSNR_PvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
<tr><th>{0}</th><th>{1}</th><th>{2}</th><th>{3}</th><th>{4}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td><td>{3}</td><td>{4}</td></tr>
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/QualityControl/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"VM8YvXMLzXyA"},"source":["** If you cannot see your files, reactivate your session by connecting to your hosted runtime.** \n","\n","\n","\"Example
Connect to a hosted runtime.
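Before entering the paths in section 3.1, it can be worth confirming that a dataset follows the format described in section 0: paired .tif stacks with identical file names in the signal and target folders, and stacks that either contain at least 32 Z slices or a number of slices that is a power of 2. The snippet below is a minimal sketch of such a check, not part of the original notebook; the two folder paths are placeholders to be replaced with your own Training_source and Training_target folders.

```python
# Minimal dataset sanity check (illustrative sketch). The folder paths are placeholders.
import os
from tifffile import imread

source_folder = "/content/gdrive/My Drive/fnet_dataset/brightfield"   # placeholder path
target_folder = "/content/gdrive/My Drive/fnet_dataset/fluorescence"  # placeholder path

source_files = sorted(f for f in os.listdir(source_folder) if f.lower().endswith((".tif", ".tiff")))
target_files = sorted(f for f in os.listdir(target_folder) if f.lower().endswith((".tif", ".tiff")))

# Corresponding signal and target images must share the same file name.
missing = set(source_files) ^ set(target_files)
print("Unpaired files:", missing if missing else "none")

# Stacks should have >= 32 Z slices, or a slice count that is a power of 2.
for name in source_files:
    n_slices = imread(os.path.join(source_folder, name)).shape[0]
    ok = n_slices >= 32 or (n_slices & (n_slices - 1)) == 0
    print(f"{name}: {n_slices} slices {'OK' if ok else '-> check slice count'}")
```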
"]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`percentage validation`** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. **Default: 4**\n","\n","**Note 2: If you only need to retrain your model after a time-out, skip this cell and go straight to section 4.2. Just make sure your training datasets are still in their original folders.**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Datasets\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","### Choosing and editing the path names ###\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","full_model_path = model_path+'/'+model_name\n","\n","new_full_model_path = convert_to_script_compatible_path(full_model_path)\n","new_full_model_path_csv = new_full_model_path+'\\/'+model_name+'\\.csv'\n","new_full_model_path_val_csv = new_full_model_path+'\\/'+model_name+'_val\\.csv'\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", skip this cell and instead load \"+model_name+\" as Pretrained_model_folder in section 4.2\")\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","\n","### Edit the train.sh script file and train.py file ###\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","with open('/content/pytorch_fnet/scripts/train_model.sh','r') as scriptfile:\n"," lines = scriptfile.readlines()\n"," if 'PATH_DATASET_VAL_CSV' not in lines:\n"," insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",10,insert,\"\")\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",22,'\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}',\"\")\n"," \n","#Clear the White space from train.sh\n","with open('/content/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/pytorch_fnet/scripts/train_model_temp.sh','/content/pytorch_fnet/scripts/train_model.sh')\n","\n","#Training parameters in fnet are indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","\n","source = os.listdir(Training_source)\n","target = os.listdir(Training_target)\n","number_of_images = len(source[:-round(len(source)*(percentage_validation/100))])\n","\n","#3. Insert the above values into train_model.sh\n","!if ! 
grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','PATH_DATASET_ALL_CSV','#PATH_DATASET_ALL_CSV')\n","#No Augmentation by default\n","Use_Data_augmentation = False\n","\n","#Load one randomly chosen training source file\n","random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Training_target)\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 
20-30 images).\n","\n","**Note 2: If you intend to use the retraining option at a later point, save the dataset in a folder in your Google Drive.**"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["Use_Data_augmentation = False #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = True #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target', flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," 
io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target'):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n"," \n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the target folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," #Here, we ensure that there aren't too many images in the buffer.\n"," #The best value will depend on the size of the images and the assigned GPU.\n"," #If too many images are loaded to the buffer the notebook will terminate the training as the RAM limit will be exceeded.\n"," if len(os.listdir(Saving_path+'/augmented_source'))>100:\n"," number_of_images = 100\n"," else:\n"," number_of_images = len(os.listdir(Saving_path+'/augmented_source'))\n","\n"," os.chdir(\"/content/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Nyf9ndiS7sL9"},"source":["#**4. Train the network**\n","---\n","\n","Before training, carefully read the different options. 
This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training on an already pre-trained model choose this section\n","\n"," **Carefully read the options before starting training.**"]},{"cell_type":"markdown","metadata":{"id":"P9OJ0nlI71Rc"},"source":["##**4.2. Start Training**\n","---\n","\n","####Play the cell below to start training. \n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model save it before playing the cell below or give your model a different name (section 3).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"id":"MQvrHFVcJ6VT","cellView":"form"},"source":["#@markdown ##Create the dataset files for training\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," print(bcolors.WARNING +\"!! Existing model \"+model_name+\" was found and overwritten!!\")\n","os.mkdir(model_path+'/'+model_name)\n","\n","#os.chdir(model_path)\n","# source = os.listdir(Training_source)\n","# target = os.listdir(Training_target)\n","\n","if Use_Data_augmentation == True:\n","\n"," aug_source = os.listdir(Saving_path+'/augmented_source')\n"," aug_val_files = aug_source[-round(len(aug_source)*(percentage_validation/100)):]\n"," aug_source_files = aug_source[:-round(len(aug_source)*(percentage_validation/100))]\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(aug_val_files)):\n"," writer.writerow([Saving_path+'/augmented_source/'+aug_val_files[i],Saving_path+\"/augmented_target/\"+aug_val_files[i]])\n","\n"," with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(aug_source_files)):\n"," writer.writerow([Saving_path+'/augmented_source/'+aug_source_files[i],Saving_path+'/augmented_target/'+aug_source_files[i]])\n","\n","else:\n"," #Here we define the random set of training files to be used for validation\n"," val_files = source[-round(len(source)*(percentage_validation/100)):]\n"," source_files = source[:-round(len(source)*(percentage_validation/100))]\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_files)):\n"," writer.writerow([Training_source+'/'+val_files[i],Training_target+'/'+val_files[i]])\n","\n","\n"," with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source_files)):\n"," 
writer.writerow([Training_source+\"/\"+source_files[i],Training_target+\"/\"+source_files[i]])"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"X8YHeSGr76je","cellView":"form"},"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n","os.chdir(\"/content/pytorch_fnet/scripts\")\n","number_of_images = 10#@param{type:\"number\"}\n","!chmod u+x train_model.sh\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"7Ofm-71T8ABX","cellView":"form"},"source":["#@markdown ##Start training\n","pdf_export(augmentation = Use_Data_augmentation)\n","start = time.time()\n","\n","#Here we import an additional module to the functions.py file to run it without errors.\n","\n","os.chdir('/content')\n","add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n","\n","### TRAIN THE MODEL ###\n","\n","print('Let''s start the training!')\n","#Here we start the training\n","!/content/pytorch_fnet/scripts/train_model.sh $model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bOdyjxWV8IrO"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out you can ignore this section.**"]},{"cell_type":"markdown","metadata":{"id":"-JxxMmVr8Tw-"},"source":["##**4.2. Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on. Make sure your training datasets are in the same location as when you originally trained the model.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. to make sure).**"]},{"cell_type":"code","metadata":{"id":"iDIgosht8U7F","cellView":"form"},"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. 
This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we repeat steps already used above in case the notebook needs to be restarted for this cell.\n","#We need to add a new line to the train.sh file\n","with open(\"/content/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","with open('/content/pytorch_fnet/scripts/train_model.sh','r') as scriptfile:\n"," lines = scriptfile.readlines()\n"," if 'PATH_DATASET_VAL_CSV' not in lines:\n"," insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",10,insert,\"\")\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",22,'\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}',\"\")\n","\n","\n","#Clear the White space from train.sh\n","with open('/content/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/pytorch_fnet/scripts/train_model_temp.sh','/content/pytorch_fnet/scripts/train_model.sh')\n","\n","\n","#Change checkpoints\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","### Choosing and editing the path names ###\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","\n","full_model_path = Pretrained_model_path+'/'+Pretrained_model_name\n","\n","new_full_model_path = convert_to_script_compatible_path(full_model_path)\n","new_full_model_path_csv = new_full_model_path+'\\/'+Pretrained_model_name+'\\.csv'\n","new_full_model_path_val_csv = new_full_model_path+'\\/'+Pretrained_model_name+'_val\\.csv'\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#We get the example data and the number of images from the csv path file#\n","\n","with open(full_model_path+'/'+Pretrained_model_name+'.csv') as csvfile:\n"," csvreader = csv.reader(csvfile)\n"," header = next(csvreader)\n"," number_of_images = 0\n"," for line in csvreader:\n"," ExampleSource = line[0]\n"," ExampleTarget = line[1]\n"," number_of_images += 1\n","\n","with open(full_model_path+'/'+Pretrained_model_name+'_val.csv') as csvfile:\n"," csvreader = csv.reader(csvfile)\n"," header = next(csvreader)\n"," number_of_val_images = 0\n"," for line in csvreader:\n"," number_of_val_images += 1\n","\n","#Batch Size\n","\n","batch_size = 4 #@param {type:\"number\"}\n","\n","# Editing the train.sh script file #\n","\n","os.chdir('/content/pytorch_fnet/scripts')\n","\n","#Change the train_model.sh file to include chosen dataset\n","!chmod u+x ./train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" 
train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n","\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','PATH_DATASET_ALL_CSV','#PATH_DATASET_ALL_CSV')\n","\n","# Find the number of steps to add and then add #\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 150#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","os.chdir('/content/pytorch_fnet/scripts')\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh\n","\n","# Display example data #\n","\n","#Load one randomly chosen training source file\n","#random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(ExampleSource)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#os.chdir(Training_target)\n","y = io.imread(ExampleTarget)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"h1INk9nRE15L"},"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n","os.chdir(\"/content/pytorch_fnet/scripts\")\n","number_of_images = 10#@param{type:\"number\"}\n","!chmod u+x train_model.sh\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"5IXdFqhM8gO2","cellView":"form"},"source":["Use_Data_augmentation = False \n","start = time.time()\n","\n","#@markdown ##4.2. 
Start re-training model\n","\n","os.chdir('/content/pytorch_fnet/fnet')\n","\n","add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n","\n","#Here we retrain the model on the chosen dataset.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/train_model.sh\n","!./scripts/train_model.sh $Pretrained_model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Here, we redefine the variable names for the pdf export\n","percentage_validation = round((number_of_val_images/(number_of_images+number_of_val_images))*100)\n","steps = new_steps\n","model_name = Pretrained_model_name\n","model_path = Pretrained_model_path\n","Training_source = os.path.dirname(ExampleSource)\n","Training_target = os.path.dirname(ExampleTarget)\n","#Create a pdf document with training summary\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/QualityControl/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. 
The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore, for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with a Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score, the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that the model needs to predict or the type of GPU allocated to your session. To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n","\n","**Note 2:** If you get a 'sequence argument must have length equal to input rank' error, you may need to reshape your images from [z, x, y, c] or [c,z,x,y] to [z,x,y] by squeezing out the channel dimension, e.g. using numpy.squeeze. 
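\n","\n","As an illustration only (this snippet is not part of the notebook's code and the file names are placeholders), squeezing out a singleton channel axis and computing the mSSIM and PSNR for one slice pair with scikit-image could look like this:\n","\n","```python\n","import numpy as np\n","from skimage import io\n","from skimage.metrics import structural_similarity, peak_signal_noise_ratio\n","\n","gt = np.squeeze(io.imread('target_stack.tif')) # hypothetical file, e.g. (1, z, x, y) -> (z, x, y)\n","pred = np.squeeze(io.imread('predicted_stack.tif')) # hypothetical file\n","z = gt.shape[0] // 2 # mid-plane slice\n","data_range = float(gt[z].max() - gt[z].min())\n","mssim = structural_similarity(gt[z], pred[z], data_range=data_range)\n","psnr_value = peak_signal_noise_ratio(gt[z], pred[z], data_range=data_range)\n","print('mSSIM:', mssim, 'PSNR:', psnr_value)\n","```\n","\n","(The QC cell below first normalizes each image pair and uses data_range=1.0; this sketch skips that normalization for brevity.)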
"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","### Choosing and editing the path names ###\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/QualityControl/Predictions\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/QualityControl/Predictions\")\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","new_full_QC_model_path = convert_to_script_compatible_path(full_QC_model_path)\n","new_full_QC_model_path_dataset = new_full_QC_model_path+'\\${DATASET}'\n","new_full_QC_model_path_csv = new_full_QC_model_path+'\\/QualityControl\\/qc\\${TEST_OR_TRAIN}\\.csv'# Get the name of the folder the test data is in\n","\n","\n","### Editing the predict.sh script file ###\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" ./scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" ./scripts/predict.sh\n","\n","!if ! 
grep class_dataset ./scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n","!if grep CziDataset ./scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' ./scripts/predict.sh; fi \n","\n","!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_QC_model_path/g\" ./scripts/predict.sh\n","!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_QC_model_path_csv\\ \\\\\\/g\" ./scripts/predict.sh\n","!sed -i \"s/path_save_dir.*/path_save_dir $new_full_QC_model_path\\/QualityControl\\/Predictions\\ \\\\\\/g\" ./scripts/predict.sh\n","\n","\n","### Create a path csv file for prediction (QC)###\n","\n","#Here we create a qctest.csv to locate the files used for QC\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","\n","with open(full_QC_model_path+'/QualityControl/qctest.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," writer.writerow([Source_QC_folder+'/'+test_signal[i],Target_QC_folder+'/'+test_signal[i]])\n","\n","### RUN THE PREDICTION ###\n","!/content/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","### Save the results ###\n","QC_results_files = os.listdir(full_QC_model_path+'/QualityControl/Predictions')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," if os.path.isdir(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]):\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/prediction_'+QC_model_name+'.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction/'+'Predicted_'+test_signal[i])\n"," if os.path.exists(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/signal.tiff'):\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Signal/'+test_signal[i])\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Target/'+test_signal[i])\n"," else:\n"," shutil.copyfile(Source_QC_folder+'/'+test_signal[i],QC_model_path+'/'+QC_model_name+'/QualityControl/Signal/'+test_signal[i])\n"," shutil.copyfile(Target_QC_folder+'/'+test_target[i],QC_model_path+'/'+QC_model_name+'/QualityControl/Target/'+test_signal[i])\n","\n","shutil.rmtree(full_QC_model_path+'/QualityControl/Predictions')\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle slice\n","img = 
io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/QualityControl/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," if len(test_GT_stack.shape) > 3:\n"," test_GT_stack = test_GT_stack.squeeze()\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," slice_number_list.append(z)\n"," 
mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","if len(img_GT.shape) > 3:\n"," img_GT = img_GT.squeeze()\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, 
vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/QualityControl/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality Control. To see values for all slices, open the .csv file saved in the Quality Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section, the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that, your trained model from section 4 is applied and the predictions are saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that the model needs to predict or the type of GPU allocated to your session. 
To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists(Results_folder+\"/Predictions\"):\n"," shutil.rmtree(Results_folder+\"/Predictions\")\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","### Choosing and editing the path names ###\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","if Use_the_current_trained_model:\n"," Prediction_model_folder = model_path+'/'+model_name\n","\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","Prediction_model_name_x = Prediction_model_name+\"}\"\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Convert the path variables into a compatible format with the script files #\n","# Prediction path conversion\n","new_full_Prediction_model_path = convert_to_script_compatible_path(full_Prediction_model_path)\n","new_full_Prediction_model_path_csv = new_full_Prediction_model_path+'\\${TEST_OR_TRAIN}\\.csv'# Get the name of the folder the test data is in\n","\n","# Result path conversion\n","new_Results_folder_path = convert_to_script_compatible_path(Results_folder)\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/predict.sh\n","\n","### Editing the predict.sh script file ###\n","\n","# Make sure the dataset type is set to .tif (debug note: could be changed at install in predict.py file?)\n","!if ! grep class_dataset ./scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n","!if grep CziDataset ./scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/pytorch_fnet/scripts/predict.sh; fi \n","\n","# We allow the maximum number of images to be processed to be higher, i.e. 
1000.\n","!sed -i \"s/N_IMAGES=.*/N_IMAGES=1000/g\" ./scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Prediction_model_name_x/g\" ./scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" ./scripts/predict.sh\n","\n","#We change the directories in the predict.sh file to our needed paths\n","!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_Prediction_model_path/g\" ./scripts/predict.sh\n","!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_Prediction_model_path_csv\\ \\\\\\/g\" ./scripts/predict.sh\n","!sed -i \"s/path_save_dir.*/path_save_dir $new_Results_folder_path\\/Predictions\\ \\\\\\/g\" ./scripts/predict.sh\n","\n","# Changing the GPU ID seems to help reduce errors\n","replace('/content/pytorch_fnet/scripts/predict.sh','${GPU_IDS}','0')\n","\n","# We get rid of the options of saving signals and targets. Here, we just want predictions.\n","insert_1 = ' --no_signal \\\\\\n'\n","insert_2 = ' --no_target \\\\\\n'\n","add_insert(\"/content/pytorch_fnet/scripts/predict.sh\",14,insert_1,\"\")\n","add_insert(\"/content/pytorch_fnet/scripts/predict.sh\",14,insert_2,\"\")\n","\n","### Create the path csv file for prediction ###\n","\n","#Here we create a new test.csv with the paths to the dataset we want to predict on.\n","test_signal = os.listdir(Data_folder)\n","with open(full_Prediction_model_path+'/test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," writer.writerow([Data_folder+\"/\"+test_signal[i],Data_folder+\"/\"+test_signal[i]])\n","\n","### WE RUN THE PREDICTION ###\n","start = time.time()\n","!/content/pytorch_fnet/scripts/predict.sh $Prediction_model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Rename the results appropriately\n","Results = os.listdir(Results_folder+'/Predictions')\n","for i in Results:\n"," if os.path.isdir(Results_folder+'/Predictions/'+i):\n"," shutil.copyfile(Results_folder+'/Predictions/'+i+'/'+os.listdir(Results_folder+'/Predictions/'+i)[0],Results_folder+'/Predictions/'+'predicted_'+test_signal[int(i)])\n"," \n","for i in Results:\n"," if os.path.isdir(Results_folder+'/Predictions/'+i):\n"," shutil.rmtree(Results_folder+'/Predictions/'+i)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bFtArIjs9tS9"},"source":["##**6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. 
Select the slice you want to visualize."]},{"cell_type":"code","metadata":{"id":"66-af3rO9vM4","cellView":"form"},"source":["#@markdown ###Which slice would you like to view?\n","slice_number = 15#@param {type:\"number\"}\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," os.chdir(Results_folder)\n","\n","#source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Predictions/predicted_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n","#Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"89tlSWBC940z"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and, after that, clean the original folder tree (datasets, results, trained model, etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"uRcJEjslvTj2"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","\n","* This version has an additional step before re-training in section 4.2, which allows you to change the number of images loaded into the buffer.\n","* An additional note is given for the QC step, indicating the shape of the image files.\n","* Existing model files are now overwritten in an additional section before the training cell, allowing errors to be corrected before the model folder is overwritten.\n","* Sections 1 and 2 are now swapped for better export of *requirements.txt*.\n","* This version also now includes a built-in version check and the version log that you're reading now."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using fnet!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb b/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb deleted file mode 100644 index f7c44035..00000000 --- a/Colab_notebooks/fnet_ZeroCostDL4Mic.ipynb +++ /dev/null @@ -1 +0,0 @@
-{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fnet_3D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1SisekHpRSJ0QKHvDePqFe09lkklVytwI","timestamp":1622728423435},{"file_id":"12UsRdIQbcWQjYewI2wrcwIWfVxc6hOfc","timestamp":1620660071757},{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1611063104553},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":["IkSguVy8Xv83","jWAz2i7RdxUV","gKDLkLWUd-YX","89tlSWBC940z","UvSlTaH14s3t"],"toc_visible":true,"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["#**Label-free prediction - fnet**\n","---\n","\n"," \n","Label-free prediction (fnet) is a neural network developped to infer the distribution of specific cellular structures from label-free images such as brightfield or EM images. It was first published in 2018 by [Ounkomol *et al.* in Nature Methods](https://www.nature.com/articles/s41592-018-0111-2). The network uses a common U-Net architecture and is trained using paired imaging volumes from the same field of view, imaged in a label-free (e.g. brightfield) and labelled condition (e.g. fluorescence images of a specific label of interest). When trained, this allows the user to identify certain structures from brightfield images alone. The performance of fnet may depend significantly on the structure at hand.\n","\n","---\n"," *Disclaimer*:\n","\n"," This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n"," This notebook is largely based on the paper: \n","\n","**Label-free prediction of three-dimensional fluorescence images from transmitted light microscopy** by Ounkomol *et al.* in Nature Methods, 2018 (https://www.nature.com/articles/s41592-018-0111-2)\n","\n"," And source code found in: /~https://github.com/AllenCellModeling/pytorch_fnet\n","\n"," **Please also cite this original paper when using or developing this notebook.** \n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. 
You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n","**Data Format**\n","\n"," **The data used to train fnet must be 3D stacks in .tiff (.tif) file format and contain the signal (e.g. bright-field image) and the target channel (e.g. fluorescence) for each field of view**. To use this notebook on user data, upload the data in the following format to your google drive. 
To ensure corresponding images are used during training give corresponding signal and target images the same name.\n","\n","Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n"," **Note: Your *dataset_folder* should not have spaces or brackets in its name as this is not recognized by the fnet code and will throw an error** \n","\n","\n","* Experiment A\n"," - **Training dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif, ...\n"," - fluorescence images\n"," - img_1.tif, img_2.tif, ...\n"," - **Quality control dataset**\n"," - bright-field images\n"," - img_1.tif, img_2.tif\n"," - fluorescence images\n"," - img_1.tif, img_2.tif\n"," - **Data to be predicted**\n"," - **Results**\n","\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","%tensorflow_version 1.x\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. 
\n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["#**2. Install fnet and dependencies**\n","---\n","Running fnet requires the fnet folder to be downloaded into the session's Files. As fnet needs several packages to be installed, this step may take a few minutes.\n","\n","You can ignore **the error warnings** as they refer to packages not required for this notebook.\n","\n","**Note: It is not necessary to keep the pytorch_fnet folder after you are finished using the notebook, so it can be deleted afterwards by playing the last cell (bottom).**"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12.2']\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","!pip install fpdf\n","\n","#@markdown ##Play this cell to install fnet and dependencies.\n","import os\n","import csv\n","import shutil\n","import random\n","from tempfile import mkstemp\n","from shutil import move, copymode\n","from os import fdopen, remove\n","import numpy as np\n","import shutil\n","from tempfile import mkstemp\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from skimage import img_as_float32\n","from distutils.dir_util import copy_tree\n","from datetime import datetime\n","from astropy.visualization import simple_norm\n","import 
time\n","from fpdf import FPDF, HTMLMixin\n","from pip._internal.operations.freeze import freeze\n","import subprocess\n","\n","#Ensure tensorflow 1.x\n","%tensorflow_version 1.x\n","import tensorflow\n","print(tensorflow.__version__)\n","\n","print(\"Tensorflow enabled.\")\n","\n","#clone fnet from github to colab\n","!git clone -b release_1 --single-branch /~https://github.com/AllenCellModeling/pytorch_fnet.git; cd pytorch_fnet; pip install .\n","!pip install -U scipy==1.2.0\n","!pip install matplotlib==2.2.3\n","from skimage import io\n","from matplotlib import pyplot as plt\n","import pandas as pd\n","import matplotlib as mpl\n","\n","\n","def replace(file_path, pattern, subst):\n"," \"\"\"Function replaces a pattern in a .py file with a new pattern.\"\"\"\n"," \n"," \"\"\"Parameters:\n"," -file_path (string): path to the file to be changed.\n"," -pattern (string): pattern to be replaced. Make sure this is as unique as possible.\n"," -subst (string): new pattern. \"\"\"\n"," #Create temp file\n"," fh, abs_path = mkstemp()\n"," with fdopen(fh,'w') as new_file:\n"," with open(file_path) as old_file:\n"," for line in old_file:\n"," new_file.write(line.replace(pattern, subst))\n"," #Copy the file permissions from the old file to the new file\n"," copymode(file_path, abs_path)\n"," #Remove original file\n"," remove(file_path)\n"," #Move new file\n"," move(abs_path, file_path)\n","\n","def add_insert(filepath,line_number,insertion,append):\n"," \"\"\"Function which inserts the a line into a document.\"\"\"\n"," \n"," \"\"\"Parameters:\n"," -filepath (string): path to the file which needs to be edited.\n"," -line (integer): Where to insert the new line. In the file, this line is ideally an empty one.\n"," -insertion (string): The line to be inserted. If it already exists it will not be added again.\n"," -append (string): If anything additional needs to be appended to the line, use this. 
Otherwise, leave as \"\" \"\"\"\n"," \n"," f = open(filepath, \"r\")\n"," contents = f.readlines()\n"," f.close()\n"," f = open(filepath, \"r\")\n"," if not insertion in f.read():\n"," contents.insert(line_number, insertion)\n"," if append != \"\":\n"," contents.append(append)\n"," f.close()\n"," f = open(filepath, \"w\")\n"," contents = \"\".join(contents)\n"," f.write(contents)\n"," f.close()\n","\n","\n","def convert_to_script_compatible_path(original_path):\n"," \"\"\"Function which converts 'original_path' into a compatible format 'new_full_path' with the fnet .sh files \"\"\"\n"," new_full_path = \"\"\n"," for s in original_path:\n"," if s=='/':\n"," new_full_path += '\\/'\n"," else:\n"," new_full_path += s\n","\n"," return new_full_path\n","\n","\n","def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," \"\"\"Percentile-based image normalization.\"\"\"\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n","def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n","def norm_minmse(gt, x, normalize_gt=True):\n"," \"\"\"This function is adapted from Martin Weigert\"\"\"\n","\n"," \"\"\"\n"," normalizes and affinely scales an image pair such that the MSE is minimized \n"," \n"," Parameters\n"," ----------\n"," gt: ndarray\n"," the ground truth image \n"," x: ndarray\n"," the image that will be affinely scaled \n"," normalize_gt: bool\n"," set to True of gt image should be normalized (default)\n"," Returns\n"," -------\n"," gt_scaled, x_scaled \n"," \"\"\"\n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","#Here we replace values in the old files\n","#Change maximum pixel number\n","replace(\"/content/pytorch_fnet/fnet/transforms.py\",'n_max_pixels=9732096','n_max_pixels=20000000')\n","replace(\"/content/pytorch_fnet/predict.py\",'6000000','20000000')\n","\n","#Prevent resizing in the training and the prediction\n","replace(\"/content/pytorch_fnet/predict.py\",\"0.37241\",\"1.0\")\n","replace(\"/content/pytorch_fnet/train_model.py\",\"0.37241\",\"1.0\")\n","\n","#Change the default dataset type in the training module to .tif\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","print(\"-------------------\")\n","print(\"Libraries installed\")\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","\n","# Check if this is the latest version of 
the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","print('Notebook version: '+Notebook_version[0])\n","strlist = Notebook_version[0].split('.')\n","Notebook_version_main = strlist[0]+'.'+strlist[1]\n","if Notebook_version_main == Latest_notebook_version.columns:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free Prediction (fnet)'\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch','scipy']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(steps)+' steps on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(32)+','+str(64)+','+str(64)+')) with a batch size of '+str(batch_size)+' and an MSE loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," #text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), torch (v '+version_numbers[2]+'), scipy (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'.'\n","\n"," #if Use_pretrained_model:\n"," # text = 'The '+Network+' model ('+model_name+') was trained using '+str(dataset_size)+' paired images (image dimensions: '+str(shape)+') using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), pytorch (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The GPU used was a '+gpu_name+'. The model was trained from the pretrained model: '+pretrained_model_path+'.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by'\n"," if Rotation:\n"," aug_text = aug_text+'\\n- rotation'\n"," if Flip:\n"," aug_text = aug_text+'\\n- flipping'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," # if Use_Default_Advanced_Parameters:\n"," # pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
percentage_validation{0}
steps{1}
batch_size{2}
\n"," \"\"\".format(percentage_validation,steps,batch_size)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(28, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair (single slice)', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_Fnet.png').shape\n"," pdf.image('/content/TrainingDataExample_Fnet.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," if trained:\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," else:\n"," pdf.output('/content/'+model_name+'_'+date_time+\"_training_report.pdf\")\n"," \n","\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'Label-free prediction (fnet)'\n"," #model_name = os.path.basename(QC_model_folder)\n"," day = datetime.now()\n"," date_time = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+date_time\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," if os.path.exists(full_QC_model_path+'/QualityControl/lossCurvePlots.png'):\n"," exp_size = io.imread(full_QC_model_path+'/QualityControl/lossCurvePlots.png').shape\n"," pdf.image(full_QC_model_path+'/QualityControl/lossCurvePlots.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," else:\n"," pdf.set_font('')\n"," pdf.set_font('Arial', 
size=10)\n"," # pdf.ln(3)\n"," pdf.multi_cell(190, 5, txt='If you would like to see the evolution of the loss function during training please play the first cell of the QC section in the notebook.')\n"," pdf.ln(3)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/QualityControl/QC_example_data.png').shape\n"," pdf.image(full_QC_model_path+'/QualityControl/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/QualityControl/QC_metrics_'+QC_model_name+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," slice_n = header[1]\n"," mSSIM_PvsGT = header[2]\n"," NRMSE_PvsGT = header[3]\n"," PSNR_PvsGT = header[4]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,mSSIM_PvsGT,NRMSE_PvsGT,PSNR_PvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," slice_n = row[1]\n"," mSSIM_PvsGT = row[2]\n"," NRMSE_PvsGT = row[3]\n"," PSNR_PvsGT = row[4]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \n"," \n"," \"\"\".format(image,slice_n,str(round(float(mSSIM_PvsGT),3)),str(round(float(NRMSE_PvsGT),3)),str(round(float(PSNR_PvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}{3}{4}
{0}{1}{2}{3}{4}
\"\"\"\n"," \n"," pdf.write_html(html)\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- Label-free prediction (fnet): Ounkomol, Chawin, et al. \"Label-free prediction of three-dimensional fluorescence images from transmitted-light microscopy.\" Nature methods 15.11 (2018): 917-920.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/QualityControl/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting the main training parameters**\n","---\n"," **Paths for training data**\n","\n"," **`Training_source`,`Training_target:`** These are the paths to your folders containing the Training_source (brightfield) and Training_target (fluorescent label) training data respectively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**Note: The stacks for fnet should either have 32 or more slices or have a number of slices which are a power of 2 (e.g. 2,4,8,16).**\n","\n"," **Training Parameters**\n","\n"," **`percentage validation`** Input the percentage of your training dataset you want to use to validate the network during training. **Default value: 10** \n","\n"," **`steps:`** Input how many iterations you want to train the network for. A larger number may improve performance but risks overfitting to the training data. To reach good performance of fnet requires several 10000's iterations which will usually require **several hours**, depending on the dataset size. **Default: 50000**\n","\n","**`batch_size:`** Reducing or increasing the **batch size** may speed up or slow down your training, respectively and can influence network performance. **Default: 4**\n","\n","**Note 2: If you only need to retrain your model after a time-out, skip this cell and go straight to section 4.2. 
Just make sure your training datasets are still in their original folders.**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["#@markdown ###Datasets\n","#Datasets\n","\n","#Change checkpoints\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","\n","### Choosing and editing the path names ###\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_source = \"\" #@param {type: \"string\"}\n","source_name = os.path.basename(os.path.normpath(Training_source))\n","\n","#Fetch the path and extract the name of the signal folder\n","Training_target = \"\" #@param {type: \"string\"}\n","target_name = os.path.basename(os.path.normpath(Training_target))\n","\n","#@markdown ###Model name and model path\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","full_model_path = model_path+'/'+model_name\n","\n","new_full_model_path = convert_to_script_compatible_path(full_model_path)\n","new_full_model_path_csv = new_full_model_path+'\\/'+model_name+'\\.csv'\n","new_full_model_path_val_csv = new_full_model_path+'\\/'+model_name+'_val\\.csv'\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", skip this cell and instead load \"+model_name+\" as Pretrained_model_folder in section 4.2\")\n","#@markdown ---\n","\n","#@markdown ###Training Parameters\n","\n","percentage_validation = 10#@param{type:\"number\"}\n","\n","model_name_x = model_name+\"}\" # this variable is only used to ensure closed curly brackets when editing the .sh files\n","\n","\n","### Edit the train.sh script file and train.py file ###\n","\n","#We need to declare that we will run validation on the dataset\n","#We need to add a new line to the train.sh file\n","with open(\"/content/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","with open('/content/pytorch_fnet/scripts/train_model.sh','r') as scriptfile:\n"," lines = scriptfile.readlines()\n"," if 'PATH_DATASET_VAL_CSV' not in lines:\n"," insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",10,insert,\"\")\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",22,'\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}',\"\")\n"," \n","#Clear the White space from train.sh\n","with open('/content/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/pytorch_fnet/scripts/train_model_temp.sh','/content/pytorch_fnet/scripts/train_model.sh')\n","\n","#Training parameters in fnet are 
indicated in the train_model.sh file.\n","#Here, we edit this file to include the desired parameters\n","\n","#1. Add permissions to train_model.sh\n","os.chdir(\"/content/pytorch_fnet/scripts\")\n","!chmod u+x train_model.sh\n","\n","#2. Select parameters\n","steps = 50000#@param {type:\"number\"}\n","batch_size = 4#@param {type:\"number\"}\n","\n","source = os.listdir(Training_source)\n","target = os.listdir(Training_target)\n","number_of_images = len(source[:-round(len(source)*(percentage_validation/100))])\n","\n","#3. Insert the above values into train_model.sh\n","!if ! grep saved_models\\/\\${ train_model.sh;then sed -i 's/saved_models\\/.*/saved_models\\/\\${DATASET}\"/g' train_model.sh; fi \n","!sed -i \"s/1:-.*/1:-$model_name_x/g\" train_model.sh #change the dataset to be trained with\n","!sed -i \"s/N_ITER=.*/N_ITER=$steps/g\" train_model.sh #change the number of training iterations (steps)\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n","\n","#We also change the training split as in our notebook the test images are used separately for prediction and we want fnet to train on the whole training data set.\n","\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','PATH_DATASET_ALL_CSV','#PATH_DATASET_ALL_CSV')\n","#No Augmentation by default\n","Use_Data_augmentation = False\n","\n","#Load one randomly chosen training source file\n","random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","os.chdir(Training_target)\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by rotating images in XY-Plane and flip them along X-Axis. This only works if the patches are square in XY.\n","\n","**Note:** Using a full augmented dataset can exceed the RAM limitations of the colab notebook. If the augmented dataset is too large, the notebook will therefore only pick a subset of the augmented dataset for training. Make sure you only augment datasets which are small (ca. 20-30 images).\n","\n","**Note 2: If you intend to use the retraining option at a later point, save the dataset in a folder in your Google Drive.**"]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["from skimage import io\n","import numpy as np\n","\n","Use_Data_augmentation = True #@param{type:\"boolean\"}\n","\n","#@markdown Select this option if you want to use augmentation to increase the size of your dataset\n","\n","#@markdown **Rotate each image 3 times by 90 degrees.**\n","Rotation = False #@param{type:\"boolean\"}\n","\n","#@markdown **Flip each image once around the x axis of the stack.**\n","Flip = True #@param{type:\"boolean\"}\n","\n","\n","#@markdown **Would you like to save your augmented images?**\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n","\n","def rotation_aug(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target', flip=False):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path)\n"," \n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," # Source Rotation\n"," source_img_90 = np.rot90(source_img,axes=(1,2))\n"," source_img_180 = np.rot90(source_img_90,axes=(1,2))\n"," source_img_270 = np.rot90(source_img_180,axes=(1,2))\n","\n"," # Target Rotation\n"," target_img_90 = np.rot90(target_img,axes=(1,2))\n"," target_img_180 = np.rot90(target_img_90,axes=(1,2))\n"," target_img_270 = np.rot90(target_img_180,axes=(1,2))\n","\n"," # Add a flip to the rotation\n"," \n"," if flip == True:\n"," source_img_lr = np.fliplr(source_img)\n"," source_img_90_lr = np.fliplr(source_img_90)\n"," source_img_180_lr = np.fliplr(source_img_180)\n"," source_img_270_lr = np.fliplr(source_img_270)\n","\n"," target_img_lr = np.fliplr(target_img)\n"," target_img_90_lr = np.fliplr(target_img_90)\n"," target_img_180_lr = np.fliplr(target_img_180)\n"," target_img_270_lr = np.fliplr(target_img_270)\n","\n"," #source_img_90_ud = np.flipud(source_img_90)\n"," \n"," # Save the augmented files\n"," # Source images\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90.tif',source_img_90)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180.tif',source_img_180)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270.tif',source_img_270)\n"," # Target images\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90.tif',target_img_90)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180.tif',target_img_180)\n"," 
io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270.tif',target_img_270)\n","\n"," if flip == True:\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',source_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',source_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',source_img_270_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_90_lr.tif',target_img_90_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_180_lr.tif',target_img_180_lr)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_270_lr.tif',target_img_270_lr)\n","\n","def flip(Source_path, Target_path, aug_source_dest='augmented_source', aug_target_dest='augmented_target'):\n"," Source_images = os.listdir(Source_path)\n"," Target_images = os.listdir(Target_path) \n","\n"," for image in Source_images:\n"," source_img = io.imread(os.path.join(Source_path,image))\n"," target_img = io.imread(os.path.join(Target_path,image))\n"," \n"," source_img_lr = np.fliplr(source_img)\n"," target_img_lr = np.fliplr(target_img)\n","\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+image,source_img)\n"," io.imsave(Saving_path+'/'+aug_source_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',source_img_lr)\n","\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+image,target_img)\n"," io.imsave(Saving_path+'/'+aug_target_dest+'/'+os.path.splitext(image)[0]+'_lr.tif',target_img_lr)\n","\n","\n","if Use_Data_augmentation:\n","\n"," if os.path.exists(Saving_path+'/augmented_source'):\n"," shutil.rmtree(Saving_path+'/augmented_source')\n"," os.mkdir(Saving_path+'/augmented_source')\n","\n"," if os.path.exists(Saving_path+'/augmented_target'):\n"," shutil.rmtree(Saving_path+'/augmented_target') \n"," os.mkdir(Saving_path+'/augmented_target')\n"," \n"," print(\"Data augmentation enabled\")\n"," print(\"Data augmentation in progress....\")\n","\n"," if Rotation == True:\n"," rotation_aug(Training_source,Training_target,flip=Flip)\n"," elif Rotation == False and Flip == True:\n"," flip(Training_source,Training_target)\n"," #Fetch the path and extract the name of the signal folder\n"," Training_source = Saving_path+\"/augmented_source\"\n"," source_name = os.path.basename(os.path.normpath(Training_source))\n","\n"," #Fetch the path and extract the name of the target folder\n"," Training_target = Saving_path+\"/augmented_target\"\n"," target_name = os.path.basename(os.path.normpath(Training_target))\n","\n"," #Here, we ensure that there aren't too many images in the buffer.\n"," #The best value will depend on the size of the images and the assigned GPU.\n"," #If too many images are loaded to the buffer the notebook will terminate the training as the RAM limit will be exceeded.\n"," if len(os.listdir(Saving_path+'/augmented_source'))>100:\n"," number_of_images = 100\n"," else:\n"," number_of_images = len(os.listdir(Saving_path+'/augmented_source'))\n","\n"," os.chdir(\"/content/pytorch_fnet/scripts\")\n"," !chmod u+x train_model.sh\n"," !sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","\n"," print(\"Done\")\n","if not 
Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Nyf9ndiS7sL9"},"source":["#**4. Train the network**\n","---\n","\n","Before training, carefully read the different options. This applies especially if you have trained fnet on a dataset before.\n","\n","\n","###**Choose one of the options to train fnet**.\n","\n","**4.1.** If this is the first training on the chosen dataset, play this section to start training.\n","\n","**4.2.** If you want to continue training on an already pre-trained model choose this section\n","\n"," **Carefully read the options before starting training.**"]},{"cell_type":"markdown","metadata":{"id":"P9OJ0nlI71Rc"},"source":["##**4.2. Start Training**\n","---\n","\n","####Play the cell below to start training. \n","\n","**Note:** If you are training with a model of the same name as before, the model will be overwritten. If you want to keep the previous model save it before playing the cell below or give your model a different name (section 3).\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"cellView":"form","id":"MQvrHFVcJ6VT"},"source":["#@markdown ##Create the dataset files for training\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n"," print(bcolors.WARNING +\"!! Existing model \"+model_name+\" was found and overwritten!!\")\n","os.mkdir(model_path+'/'+model_name)\n","\n","#os.chdir(model_path)\n","# source = os.listdir(Training_source)\n","# target = os.listdir(Training_target)\n","\n","if Use_Data_augmentation == True:\n","\n"," aug_source = os.listdir(Saving_path+'/augmented_source')\n"," aug_val_files = aug_source[-round(len(aug_source)*(percentage_validation/100)):]\n"," aug_source_files = aug_source[:-round(len(aug_source)*(percentage_validation/100))]\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(aug_val_files)):\n"," writer.writerow([Saving_path+'/augmented_source/'+aug_val_files[i],Saving_path+\"/augmented_target/\"+aug_val_files[i]])\n","\n"," with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(aug_source_files)):\n"," writer.writerow([Saving_path+'/augmented_source/'+aug_source_files[i],Saving_path+'/augmented_target/'+aug_source_files[i]])\n","\n","else:\n"," #Here we define the random set of training files to be used for validation\n"," val_files = source[-round(len(source)*(percentage_validation/100)):]\n"," source_files = source[:-round(len(source)*(percentage_validation/100))]\n","\n"," #Finally, we create a validation csv file to construct the validation dataset\n"," with open(model_path+'/'+model_name+'/'+model_name+'_val.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(val_files)):\n"," 
writer.writerow([Training_source+'/'+val_files[i],Training_target+'/'+val_files[i]])\n","\n","\n"," with open(model_path+'/'+model_name+'/'+model_name+'.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(source_files)):\n"," writer.writerow([Training_source+\"/\"+source_files[i],Training_target+\"/\"+source_files[i]])"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"X8YHeSGr76je","cellView":"form"},"source":["#@markdown ####If your dataset is large the notebook might crash unexpectedly when loading the training data into the buffer. If this happens, reduce the number of images to be loaded into the buffer and restart the training.\n","os.chdir(\"/content/pytorch_fnet/scripts\")\n","number_of_images = 25#@param{type:\"number\"}\n","!chmod u+x train_model.sh\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"7Ofm-71T8ABX","cellView":"form"},"source":["#@markdown ##Start training\n","pdf_export(augmentation = Use_Data_augmentation)\n","start = time.time()\n","\n","#This tifffile release runs error-free in this version of fnet.\n","!pip install tifffile==2019.7.26\n","\n","#Here we import an additional module to the functions.py file to run it without errors.\n","\n","os.chdir('/content')\n","add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n","\n","### TRAIN THE MODEL ###\n","\n","print('Let''s start the training!')\n","#Here we start the training\n","!/content/pytorch_fnet/scripts/train_model.sh $model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Create a pdf document with training summary\n","\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bOdyjxWV8IrO"},"source":["**Note:** Fnet takes a long time for training. If your notebook times out due to the length of the training or due to a loss of GPU acceleration the last checkpoint will be saved in the saved_models folder in the pytorch_fnet folder. If you want to save it in a more convenient location on your drive, remount the drive (if you got disconnected) and in the next cell enter the location (`model_path`) where you want to save the model (`model_name`) before continuing in 4.2. **If you did not time out you can ignore this section.**"]},{"cell_type":"markdown","metadata":{"id":"-JxxMmVr8Tw-"},"source":["##**4.2. Training from a previously saved model**\n","---\n","This section allows you to use networks you have previously trained and saved and to continue training them for more training steps. The folders have the same meaning as above (3.1.). If you want to save the previously trained model, create a copy now as this section will overwrite the weights of the old model. **You can currently only train the model with the same dataset and batch size that the network was previously trained on. Make sure your training datasets are in the same location as when you originally trained the model.**\n","\n","**Note: To use this section the *pytorch_fnet* folder must be in your *gdrive/My Drive*. (Simply, play cell 2. 
to make sure).**"]},{"cell_type":"code","metadata":{"id":"iDIgosht8U7F","cellView":"form"},"source":["#@markdown To test if performance improves after the initial training, you can continue training on the old model. This option can also be useful if Colab disconnects or times out.\n","#@markdown Enter the paths of the datasets you want to continue training on.\n","\n","#Here we repeat steps already used above in case the notebook needs to be restarted for this cell.\n","#We need to add a new line to the train.sh file\n","with open(\"/content/pytorch_fnet/scripts/train_model.sh\", \"r\") as f:\n"," if not \"gpu_ids ${GPU_IDS} \\\\\" in f.read():\n"," replace(\"/content/pytorch_fnet/scripts/train_model.sh\",\" --gpu_ids ${GPU_IDS}\",\" --gpu_ids ${GPU_IDS} \\\\\")\n","\n","#We add the necessary validation parameters here.\n","with open('/content/pytorch_fnet/scripts/train_model.sh','r') as scriptfile:\n"," lines = scriptfile.readlines()\n"," if 'PATH_DATASET_VAL_CSV' not in lines:\n"," insert = 'PATH_DATASET_VAL_CSV=\"data/csvs/${DATASET}_val.csv\"'\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",10,insert,\"\")\n"," add_insert(\"/content/pytorch_fnet/scripts/train_model.sh\",22,'\\n --path_dataset_val_csv ${PATH_DATASET_VAL_CSV}',\"\")\n","\n","\n","#Clear the White space from train.sh\n","with open('/content/pytorch_fnet/scripts/train_model.sh', 'r') as inFile,\\\n"," open('/content/pytorch_fnet/scripts/train_model_temp.sh', 'w') as outFile:\n"," for line in inFile:\n"," if line.strip():\n"," outFile.write(line)\n","os.remove('/content/pytorch_fnet/scripts/train_model.sh')\n","os.rename('/content/pytorch_fnet/scripts/train_model_temp.sh','/content/pytorch_fnet/scripts/train_model.sh')\n","\n","\n","#Change checkpoints\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--interval_save', type=int, default=500\",\"'--interval_save', type=int, default=100\")\n","\n","#Adapt Class Dataset for Tiff files\n","replace(\"/content/pytorch_fnet/train_model.py\",\"'--class_dataset', default='CziDataset'\",\"'--class_dataset', default='TiffDataset'\")\n","\n","### Choosing and editing the path names ###\n","\n","Pretrained_model_folder = \"\" #@param{type:\"string\"}\n","Pretrained_model_name = os.path.basename(Pretrained_model_folder)\n","Pretrained_model_path = os.path.dirname(Pretrained_model_folder)\n","\n","full_model_path = Pretrained_model_path+'/'+Pretrained_model_name\n","\n","new_full_model_path = convert_to_script_compatible_path(full_model_path)\n","new_full_model_path_csv = new_full_model_path+'\\/'+Pretrained_model_name+'\\.csv'\n","new_full_model_path_val_csv = new_full_model_path+'\\/'+Pretrained_model_name+'_val\\.csv'\n","\n","Pretrained_model_name_x = Pretrained_model_name+\"}\"\n","\n","#We get the example data and the number of images from the csv path file#\n","\n","with open(full_model_path+'/'+Pretrained_model_name+'.csv') as csvfile:\n"," csvreader = csv.reader(csvfile)\n"," header = next(csvreader)\n"," number_of_images = 0\n"," for line in csvreader:\n"," ExampleSource = line[0]\n"," ExampleTarget = line[1]\n"," number_of_images += 1\n","\n","with open(full_model_path+'/'+Pretrained_model_name+'_val.csv') as csvfile:\n"," csvreader = csv.reader(csvfile)\n"," header = next(csvreader)\n"," number_of_val_images = 0\n"," for line in csvreader:\n"," number_of_val_images += 1\n","\n","#Batch Size\n","\n","batch_size = 4 #@param {type:\"number\"}\n","\n","# Editing the train.sh script file #\n","\n","os.chdir('/content/pytorch_fnet/scripts')\n","\n","#Change 
the train_model.sh file to include chosen dataset\n","!chmod u+x ./train_model.sh\n","!sed -i \"s/1:-.*/1:-$Pretrained_model_name_x/g\" train_model.sh\n","!sed -i \"s/train_size .* -v/train_size 1.0 -v/g\" train_model.sh #Use the whole training dataset for training\n","!sed -i \"s/BUFFER_SIZE=.*/BUFFER_SIZE=$number_of_images/g\" train_model.sh #change the number of training images\n","!sed -i \"s/BATCH_SIZE=.*/BATCH_SIZE=$batch_size/g\" train_model.sh #change the batch size\n","\n","!sed -i 's/RUN_DIR=.*/RUN_DIR=\"$new_full_model_path\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_TRAIN_CSV=.*/PATH_DATASET_TRAIN_CSV=\"$new_full_model_path_csv\"/g' train_model.sh\n","!sed -i 's/PATH_DATASET_VAL_CSV=.*/PATH_DATASET_VAL_CSV=\"$new_full_model_path_val_csv\"/g' train_model.sh\n","\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python scripts', '#python scripts')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','python train_model.py', 'python /content/pytorch_fnet/train_model.py')\n","replace('/content/pytorch_fnet/scripts/train_model.sh','PATH_DATASET_ALL_CSV','#PATH_DATASET_ALL_CSV')\n","\n","# Find the number of steps to add and then add #\n","with open(Pretrained_model_folder+'/losses.csv') as f:\n"," previous_steps = sum(1 for line in f)\n","print('continuing training after step '+str(previous_steps-1))\n","\n","print('To start re-training play section 4.2. below')\n","\n","#@markdown For how many additional steps do you want to train the model?\n","add_steps = 10000#@param {type:\"number\"}\n","\n","#Calculate the new number of total training epochs. Subtract 1 to discount the title row of the csv file.\n","new_steps = previous_steps + add_steps -1\n","os.chdir('/content/pytorch_fnet/scripts')\n","\n","#Edit train_model.sh file to include new total number of training epochs\n","!sed -i \"s/N_ITER=.*/N_ITER=$new_steps/g\" train_model.sh\n","\n","# Display example data #\n","\n","#Load one randomly chosen training source file\n","#random_choice=random.choice(os.listdir(Training_source))\n","x = io.imread(ExampleSource)\n","\n","#Find image Z dimension and select the mid-plane\n","Image_Z = x.shape[0]\n","mid_plane = int(Image_Z / 2)+1\n","\n","#os.chdir(Training_target)\n","y = io.imread(ExampleTarget)\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x[mid_plane], norm=simple_norm(x[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Source (single Z plane)');\n","plt.subplot(1,2,2)\n","plt.imshow(y[mid_plane], norm=simple_norm(y[mid_plane], percent = 99), interpolation='nearest')\n","plt.axis('off')\n","plt.title('Training Target (single Z plane)');\n","plt.savefig('/content/TrainingDataExample_Fnet.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"5IXdFqhM8gO2","cellView":"form"},"source":["Use_Data_augmentation = False \n","start = time.time()\n","\n","#@markdown ##4.2. 
Start re-training model\n","!pip install tifffile==2019.7.26\n","\n","os.chdir('/content/pytorch_fnet/fnet')\n","\n","add_insert(\"/content/pytorch_fnet/fnet/functions.py\",5,\"import fnet.fnet_model\",\"\")\n","\n","#Here we retrain the model on the chosen dataset.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/train_model.sh\n","!./scripts/train_model.sh $Pretrained_model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Here, we redefine the variable names for the pdf export\n","percentage_validation = round((number_of_val_images/(number_of_images+number_of_val_images))*100)\n","steps = new_steps\n","model_name = Pretrained_model_name\n","model_path = Pretrained_model_path\n","Training_source = os.path.dirname(ExampleSource)\n","Training_target = os.path.dirname(ExampleTarget)\n","#Create a pdf document with training summary\n","pdf_export(trained = True, augmentation = Use_Data_augmentation)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the name of the model and path to model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," print(\"Using current trained network\")\n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","#Create a folder for the quality control metrics\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\")\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/QualityControl\")\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dhJROwlAMv5o"},"source":["## **5.1. Inspection of the loss function**\n","---\n","\n","First, it is good practice to evaluate the training progress by comparing the training loss with the validation loss. The latter is a metric which shows how well the network performs on a subset of unseen data which is set aside from the training dataset. 
For more information on this, see for example [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols *et al.*\n","\n","**Training loss** describes an error value after each epoch for the difference between the model's prediction and its ground-truth target.\n","\n","**Validation loss** describes the same error value between the model's prediction on a validation image and compared to it's target.\n","\n","During training both values should decrease before reaching a minimal value which does not decrease further even after more training. Comparing the development of the validation loss with the training loss can give insights into the model's performance.\n","\n","Decreasing **Training loss** and **Validation loss** indicates that training is still necessary and increasing the `number_of_epochs` is recommended. Note that the curves can look flat towards the right side, just because of the y-axis scaling. The network has reached convergence once the curves flatten out. After this point no further training is required. If the **Validation loss** suddenly increases again an the **Training loss** simultaneously goes towards zero, it means that the network is overfitting to the training data. In other words the network is remembering the exact patterns from the training data and no longer generalizes well to unseen data. In this case the training dataset has to be increased.\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"vMzSP50kMv5p","cellView":"form"},"source":["#@markdown ##Play the cell to show figure of training errors\n","\n","lossDataFromCSV = []\n","vallossDataFromCSV = []\n","\n","iterationNumber_training = []\n","iterationNumber_val = []\n","\n","import csv\n","from matplotlib import pyplot as plt\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses.csv','r') as csvfile:\n"," plots = csv.reader(csvfile, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_training.append(int(row[0]))\n"," lossDataFromCSV.append(float(row[1]))\n","\n","with open(QC_model_path+'/'+QC_model_name+'/'+'losses_val.csv','r') as csvfile_val:\n"," plots = csv.reader(csvfile_val, delimiter=',')\n"," next(plots)\n"," for row in plots:\n"," iterationNumber_val.append(int(row[0]))\n"," vallossDataFromCSV.append(float(row[1]))\n","\n","plt.figure(figsize=(15,10))\n","\n","plt.subplot(2,1,1)\n","plt.plot(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.plot(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (linear scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","\n","plt.subplot(2,1,2)\n","plt.semilogy(iterationNumber_training, lossDataFromCSV, label='Training loss')\n","plt.semilogy(iterationNumber_val, vallossDataFromCSV, label='Validation loss')\n","plt.title('Training loss and validation loss vs. iteration number (log scale)')\n","plt.ylabel('Loss')\n","plt.xlabel('Iteration')\n","plt.legend()\n","plt.savefig(QC_model_path+'/'+QC_model_name+'/QualityControl/lossCurvePlots.png', bbox_inches='tight', pad_inches=0)\n","plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"X5_92nL2xdP6"},"source":["## **5.2. Error mapping and quality metrics estimation**\n","---\n","\n","This section will display SSIM maps and RSE maps as well as calculating total SSIM, NRMSE and PSNR metrics for all the images provided in the \"Source_QC_folder\" and \"Target_QC_folder\" !\n","\n","**1. 
The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as window of 11 pixels and with Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the SSIM value calculated across the entire window of both images.\n","\n","**The output below shows the SSIM maps with the mSSIM**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalized predicted and target or the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. 
To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"w90MdriMxhjD","cellView":"form"},"source":["!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","from distutils.dir_util import copy_tree\n","\n","#----------------CREATING PREDICTIONS FOR QUALITY CONTROL----------------------------------#\n","\n","### Choosing and editing the path names ###\n","\n","#Choose the folder with the quality control datasets\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","Predictions_name = \"QualityControl\" \n","Predictions_name_x = Predictions_name+\"}\"\n","\n","#If the folder you are creating already exists, delete the existing version to overwrite.\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/QualityControl/Predictions\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/QualityControl/Predictions\")\n","\n","# Get the name of the folder the test data is in\n","source_dataset_name = os.path.basename(os.path.normpath(Source_QC_folder))\n","target_dataset_name = os.path.basename(os.path.normpath(Target_QC_folder))\n","\n","new_full_QC_model_path = convert_to_script_compatible_path(full_QC_model_path)\n","new_full_QC_model_path_dataset = new_full_QC_model_path+'\\${DATASET}'\n","new_full_QC_model_path_csv = new_full_QC_model_path+'\\/QualityControl\\/qc\\${TEST_OR_TRAIN}\\.csv'# Get the name of the folder the test data is in\n","\n","\n","### Editing the predict.sh script file ###\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Predictions_name_x/g\" ./scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" ./scripts/predict.sh\n","\n","!if ! 
grep class_dataset ./scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n","!if grep CziDataset ./scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' ./scripts/predict.sh; fi \n","\n","!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_QC_model_path/g\" ./scripts/predict.sh\n","!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_QC_model_path_csv\\ \\\\\\/g\" ./scripts/predict.sh\n","!sed -i \"s/path_save_dir.*/path_save_dir $new_full_QC_model_path\\/QualityControl\\/Predictions\\ \\\\\\/g\" ./scripts/predict.sh\n","\n","\n","### Create a path csv file for prediction (QC)###\n","\n","#Here we create a qctest.csv to locate the files used for QC\n","test_signal = os.listdir(Source_QC_folder)\n","test_target = os.listdir(Target_QC_folder)\n","\n","with open(full_QC_model_path+'/QualityControl/qctest.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," writer.writerow([Source_QC_folder+'/'+test_signal[i],Target_QC_folder+'/'+test_signal[i]])\n","\n","### RUN THE PREDICTION ###\n","!/content/pytorch_fnet/scripts/predict.sh $Predictions_name 0\n","\n","### Save the results ###\n","QC_results_files = os.listdir(full_QC_model_path+'/QualityControl/Predictions')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Signal')\n","\n","if os.path.exists(QC_model_path+'/'+QC_model_name+'/QualityControl/Target'):\n"," shutil.rmtree(QC_model_path+'/'+QC_model_name+'/QualityControl/Target')\n","os.mkdir(QC_model_path+'/'+QC_model_name+'/QualityControl/Target')\n","\n","for i in range(len(QC_results_files)-2):\n"," if os.path.isdir(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]):\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/prediction_'+QC_model_name+'.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Prediction/'+'Predicted_'+test_signal[i])\n"," if os.path.exists(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/signal.tiff'):\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/signal.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Signal/'+test_signal[i])\n"," shutil.copyfile(full_QC_model_path+'/QualityControl/Predictions/'+QC_results_files[i]+'/target.tiff', QC_model_path+'/'+QC_model_name+'/QualityControl/Target/'+test_signal[i])\n"," else:\n"," shutil.copyfile(Source_QC_folder+'/'+test_signal[i],QC_model_path+'/'+QC_model_name+'/QualityControl/Signal/'+test_signal[i])\n"," shutil.copyfile(Target_QC_folder+'/'+test_target[i],QC_model_path+'/'+QC_model_name+'/QualityControl/Target/'+test_signal[i])\n","\n","shutil.rmtree(full_QC_model_path+'/QualityControl/Predictions')\n","\n","#-----------------------------METRICS EVALUATION-------------------------------#\n","\n","# Calculating the position of the mid-plane slice\n","# Perform prediction on all datasets in the Source_QC folder\n","\n","#Finding the middle slice\n","img = 
io.imread(os.path.join(Source_QC_folder, os.listdir(Source_QC_folder)[0]))\n","n_slices = img.shape[0]\n","z_mid_plane = int(n_slices / 2)+1\n","\n","path_metrics_save = QC_model_path+'/'+QC_model_name+'/QualityControl/'\n","\n","# Open and create the csv file that will contain all the QC metrics\n","with open(path_metrics_save+'QC_metrics_'+QC_model_name+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"File name\",\"Slice #\",\"Prediction v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Prediction v. GT PSNR\"]) \n"," \n"," # These lists will be used to collect all the metrics values per slice\n"," file_name_list = []\n"," slice_number_list = []\n"," mSSIM_GvP_list = []\n"," NRMSE_GvP_list = []\n"," PSNR_GvP_list = []\n","\n"," # These lists will be used to display the mean metrics for the stacks\n"," mSSIM_GvP_list_mean = []\n"," NRMSE_GvP_list_mean = []\n"," PSNR_GvP_list_mean = []\n","\n"," # Let's loop through the provided dataset in the QC folders\n"," for thisFile in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder, thisFile)):\n"," print('Running QC on: '+thisFile)\n","\n"," test_GT_stack = io.imread(os.path.join(Target_QC_folder, thisFile))\n"," if len(test_GT_stack.shape) > 3:\n"," test_GT_stack = test_GT_stack.squeeze()\n"," test_source_stack = io.imread(os.path.join(Source_QC_folder,thisFile))\n"," test_prediction_stack = io.imread(os.path.join(path_metrics_save+\"Prediction/\",'Predicted_'+thisFile))\n"," test_prediction_stack = np.squeeze(test_prediction_stack,axis=(0,))\n"," n_slices = test_GT_stack.shape[0]\n","\n"," img_SSIM_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n"," img_RSE_GTvsPrediction_stack = np.zeros((n_slices, test_GT_stack.shape[1], test_GT_stack.shape[2]))\n","\n"," for z in range(n_slices): \n"," \n"," # -------------------------------- Prediction --------------------------------\n","\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT_stack[z], test_prediction_stack[z], normalize_gt=True)\n","\n"," # -------------------------------- Calculate the SSIM metric and maps --------------------------------\n","\n"," # Calculate the SSIM maps and index\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = structural_similarity(test_GT_norm, test_prediction_norm, data_range=1.0, full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n"," #Calculate ssim_maps\n"," img_SSIM_GTvsPrediction_stack[z] = np.float32(img_SSIM_GTvsPrediction)\n"," \n","\n"," # -------------------------------- Calculate the NRMSE metrics --------------------------------\n","\n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n","\n"," # Calculate SE maps\n"," img_RSE_GTvsPrediction_stack[z] = np.float32(img_RSE_GTvsPrediction)\n","\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n","\n","\n"," # Calculate the PSNR between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n","\n","\n"," writer.writerow([thisFile, str(z),str(index_SSIM_GTvsPrediction),str(NRMSE_GTvsPrediction),str(PSNR_GTvsPrediction)])\n"," \n"," # Collect values to display in dataframe output\n"," #file_name_list.append(thisFile)\n"," slice_number_list.append(z)\n"," 
mSSIM_GvP_list.append(index_SSIM_GTvsPrediction)\n","\n"," NRMSE_GvP_list.append(NRMSE_GTvsPrediction)\n","\n"," PSNR_GvP_list.append(PSNR_GTvsPrediction)\n","\n","\n"," if (z == z_mid_plane): # catch these for display\n"," SSIM_GTvsP_forDisplay = index_SSIM_GTvsPrediction\n","\n"," NRMSE_GTvsP_forDisplay = NRMSE_GTvsPrediction\n","\n"," \n"," # If calculating average metrics for dataframe output\n"," file_name_list.append(thisFile)\n"," mSSIM_GvP_list_mean.append(sum(mSSIM_GvP_list)/len(mSSIM_GvP_list))\n","\n"," NRMSE_GvP_list_mean.append(sum(NRMSE_GvP_list)/len(NRMSE_GvP_list))\n","\n"," PSNR_GvP_list_mean.append(sum(PSNR_GvP_list)/len(PSNR_GvP_list))\n","\n"," # ----------- Change the stacks to 32 bit images -----------\n"," img_SSIM_GTvsPrediction_stack_32 = img_as_float32(img_SSIM_GTvsPrediction_stack, force_copy=False)\n"," img_RSE_GTvsPrediction_stack_32 = img_as_float32(img_RSE_GTvsPrediction_stack, force_copy=False)\n","\n","\n"," # ----------- Saving the error map stacks -----------\n"," io.imsave(path_metrics_save+'SSIM_GTvsPrediction_'+thisFile,img_SSIM_GTvsPrediction_stack_32)\n"," io.imsave(path_metrics_save+'RSE_GTvsPrediction_'+thisFile,img_RSE_GTvsPrediction_stack_32)\n","\n","#Averages of the metrics per stack as dataframe output\n","pdResults = pd.DataFrame(file_name_list, columns = [\"File name\"])\n","pdResults[\"Prediction v. GT mSSIM\"] = mSSIM_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT NRMSE\"] = NRMSE_GvP_list_mean\n","\n","pdResults[\"Prediction v. GT PSNR\"] = PSNR_GvP_list_mean\n","\n","pdResults.head()\n","\n","# All data is now processed saved\n","Test_FileList = os.listdir(Source_QC_folder) # this assumes, as it should, that both source and target are named the same way\n","\n","plt.figure(figsize=(15,10))\n","# Currently only displays the last computed set, from memory\n","\n","# Target (Ground-truth)\n","plt.subplot(2,3,1)\n","plt.axis('off')\n","img_GT = io.imread(os.path.join(Target_QC_folder, Test_FileList[-1]))\n","if len(img_GT.shape) > 3:\n"," img_GT = img_GT.squeeze()\n","plt.imshow(img_GT[z_mid_plane])\n","plt.title('Target (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Setting up colours\n","cmap = plt.cm.Greys\n","\n","\n","# Source\n","plt.subplot(2,3,2)\n","plt.axis('off')\n","img_Source = io.imread(os.path.join(Source_QC_folder, Test_FileList[-1]))\n","plt.imshow(img_Source[z_mid_plane],aspect='equal',cmap=cmap)\n","plt.title('Source (slice #'+str(z_mid_plane)+')')\n","\n","\n","#Prediction\n","plt.subplot(2,3,3)\n","plt.axis('off')\n","img_Prediction = io.imread(os.path.join(path_metrics_save+'Prediction/', 'Predicted_'+Test_FileList[-1]))\n","img_Prediction = np.squeeze(img_Prediction,axis=(0,))\n","plt.imshow(img_Prediction[z_mid_plane])\n","plt.title('Prediction (slice #'+str(z_mid_plane)+')')\n","\n","#Setting up colours\n","cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Prediction\n","plt.subplot(2,3,5)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_SSIM_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'SSIM_GTvsPrediction_'+Test_FileList[-1]))\n","imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction[z_mid_plane], cmap = cmap, 
vmin=0,vmax=1)\n","plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n","plt.title('SSIM map: Target vs. Prediction',fontsize=15)\n","plt.xlabel('mSSIM: '+str(round(SSIM_GTvsP_forDisplay,3)),fontsize=14)\n","\n","\n","#Root Squared Error between GT and Prediction\n","plt.subplot(2,3,6)\n","#plt.axis('off')\n","plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","img_RSE_GTvsPrediction = io.imread(os.path.join(path_metrics_save, 'RSE_GTvsPrediction_'+Test_FileList[-1]))\n","imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction[z_mid_plane], cmap = cmap, vmin=0, vmax=1)\n","plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n","plt.title('RSE map Target vs. Prediction',fontsize=15)\n","plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsP_forDisplay,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","plt.savefig(full_QC_model_path+'/QualityControl/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","print('-----------------------------------')\n","print('Here are the average scores for the stacks you tested in Quality control. To see values for all slices, open the .csv file saved in the Qulity Control folder.')\n","pdResults.head()\n","\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["#**6. Using the trained model**\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. Predicted output images are saved in your **Results_folder** folder.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Results_folder`:** This folder will contain the predicted output images.\n","\n","If you want to use a model different from the most recently trained one, untick the box and enter the path of the model in **`Prediction_model_folder`**.\n","\n","**Note: `Prediction_model_folder` expects a folder name which contains a model.p file from a previous training.**\n","\n","**Note:** If you receive a *CUDA out of memory* error, this can be caused by the size of the data that model needs to predict or the type of GPU has allocated to your session. 
To solve this issue, you can *factory reset runtime* to attempt to connect to a different GPU or use a dataset with smaller images.\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#Before prediction we will remove the old prediction folder because fnet won't execute if a path already exists that has the same name.\n","#This is just in case you have already trained on a dataset with the same name\n","#The data will be saved outside of the pytorch_folder (Results_folder) so it won't be lost when you run this section again.\n","\n","!pip install -U scipy==1.2.0\n","!pip install --no-cache-dir tifffile==2019.7.26 \n","\n","#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Results_folder = \"\" #@param {type:\"string\"}\n","\n","if os.path.exists(Results_folder+\"/Predictions\"):\n"," shutil.rmtree(Results_folder+\"/Predictions\")\n","\n","#@markdown ###Do you want to use the current trained model?\n","\n","Use_the_current_trained_model = True #@param{type:\"boolean\"}\n","#@markdown ###If not, provide the name of the model you want to use \n","\n","### Choosing and editing the path names ###\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","if Use_the_current_trained_model:\n"," Prediction_model_folder = model_path+'/'+model_name\n","\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","Prediction_model_name_x = Prediction_model_name+\"}\"\n","\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","if os.path.exists(full_Prediction_model_path):\n"," print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n","# Convert the path variables into a compatible format with the script files #\n","# Prediction path conversion\n","new_full_Prediction_model_path = convert_to_script_compatible_path(full_Prediction_model_path)\n","new_full_Prediction_model_path_csv = new_full_Prediction_model_path+'\\${TEST_OR_TRAIN}\\.csv'# Get the name of the folder the test data is in\n","\n","# Result path conversion\n","new_Results_folder_path = convert_to_script_compatible_path(Results_folder)\n","\n","# Get permission to the predict.sh file and change the name of the dataset to the Predictions_folder.\n","os.chdir('/content/pytorch_fnet/')\n","!chmod u+x ./scripts/predict.sh\n","\n","### Editing the predict.sh script file ###\n","\n","# Make sure the dataset type is set to .tif (debug note: could be changed at install in predict.py file?)\n","!if ! grep class_dataset ./scripts/predict.sh;then sed -i 's/DIR} \\\\/DIR} \\\\\\'$''\\n' --class_dataset TiffDataset \\\\/' /content/pytorch_fnet/scripts/predict.sh; fi\n","!if grep CziDataset ./scripts/predict.sh;then sed -i 's/CziDataset/TiffDataset/' /content/pytorch_fnet/scripts/predict.sh; fi \n","\n","# We allow the maximum number of images to be processed to be higher, i.e. 
1000.\n","!sed -i \"s/N_IMAGES=.*/N_IMAGES=1000/g\" ./scripts/predict.sh\n","!sed -i \"s/1:-.*/1:-$Prediction_model_name_x/g\" ./scripts/predict.sh\n","\n","#Here, we remove the 'train' option from predict.sh as we don't need to run predictions on the train data.\n","!sed -i \"s/in test.*/in test/g\" ./scripts/predict.sh\n","\n","#We change the directories in the predict.sh file to our needed paths\n","!sed -i \"s/MODEL_DIR=.*/MODEL_DIR=$new_full_Prediction_model_path/g\" ./scripts/predict.sh\n","!sed -i \"s/path_dataset_csv.*/path_dataset_csv\\ $new_full_Prediction_model_path_csv\\ \\\\\\/g\" ./scripts/predict.sh\n","!sed -i \"s/path_save_dir.*/path_save_dir $new_Results_folder_path\\/Predictions\\ \\\\\\/g\" ./scripts/predict.sh\n","\n","# Changing the GPU ID seems to help reduce errors\n","replace('/content/pytorch_fnet/scripts/predict.sh','${GPU_IDS}','0')\n","\n","# We get rid of the options of saving signals and targets. Here, we just want predictions.\n","insert_1 = ' --no_signal \\\\\\n'\n","insert_2 = ' --no_target \\\\\\n'\n","add_insert(\"/content/pytorch_fnet/scripts/predict.sh\",14,insert_1,\"\")\n","add_insert(\"/content/pytorch_fnet/scripts/predict.sh\",14,insert_2,\"\")\n","\n","### Create the path csv file for prediction ###\n","\n","#Here we create a new test.csv with the paths to the dataset we want to predict on.\n","test_signal = os.listdir(Data_folder)\n","with open(full_Prediction_model_path+'/test.csv', 'w', newline='') as file:\n"," writer = csv.writer(file)\n"," writer.writerow([\"path_signal\",\"path_target\"])\n"," for i in range(0,len(test_signal)):\n"," writer.writerow([Data_folder+\"/\"+test_signal[i],Data_folder+\"/\"+test_signal[i]])\n","\n","### WE RUN THE PREDICTION ###\n","start = time.time()\n","!/content/pytorch_fnet/scripts/predict.sh $Prediction_model_name 0\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","#Rename the results appropriately\n","Results = os.listdir(Results_folder+'/Predictions')\n","for i in Results:\n"," if os.path.isdir(Results_folder+'/Predictions/'+i):\n"," shutil.copyfile(Results_folder+'/Predictions/'+i+'/'+os.listdir(Results_folder+'/Predictions/'+i)[0],Results_folder+'/Predictions/'+'predicted_'+test_signal[int(i)])\n"," \n","for i in Results:\n"," if os.path.isdir(Results_folder+'/Predictions/'+i):\n"," shutil.rmtree(Results_folder+'/Predictions/'+i)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bFtArIjs9tS9"},"source":["##**6.2. Assess predicted output**\n","---\n","Here, we inspect an example prediction from the predictions on the test dataset. 
Select the slice of the slice you want to visualize."]},{"cell_type":"code","metadata":{"id":"66-af3rO9vM4","cellView":"form"},"source":["!pip install matplotlib==2.2.3\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from skimage import io\n","import os\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","#@markdown ###Select the slice would you like to view?\n","slice_number = 15#@param {type:\"number\"}\n","\n","def show_image(file=os.listdir(Data_folder)):\n"," os.chdir(Results_folder)\n","\n","#source_image = io.imread(test_signal[0])\n"," source_image = io.imread(os.path.join(Data_folder,file))\n"," prediction_image = io.imread(os.path.join(Results_folder,'Predictions/predicted_'+file))\n"," prediction_image = np.squeeze(prediction_image, axis=(0,))\n","\n","#Create the figure\n"," fig = plt.figure(figsize=(10,20))\n","\n"," #Setting up colours\n"," cmap = plt.cm.Greys\n","\n"," plt.subplot(1,2,1)\n"," print(prediction_image.shape)\n"," plt.imshow(source_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Source')\n"," plt.subplot(1,2,2)\n"," plt.imshow(prediction_image[slice_number], cmap = cmap, aspect = 'equal')\n"," plt.title('Prediction')\n","\n","interact(show_image, continuous_update=False);"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"89tlSWBC940z"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["#**Thank you for using fnet!**"]}]} \ No newline at end of file diff --git a/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb b/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb index ac855c97..06fbe32d 100644 --- a/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb +++ b/Colab_notebooks/pix2pix_ZeroCostDL4Mic.ipynb @@ -1,2831 +1 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "pix2pix_ZeroCostDL4Mic.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true, - "machine_shape": "hm" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.7" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "IkSguVy8Xv83" - }, - "source": [ - "# **pix2pix**\n", - "\n", - "---\n", - "\n", - "pix2pix is a deep-learning method allowing image-to-image translation from one image domain type to another image domain type. It was first published by [Isola *et al.* in 2016](https://arxiv.org/abs/1611.07004). The image transformation requires paired images for training (supervised learning) and is made possible here by using a conditional Generative Adversarial Network (GAN) architecture to use information from the input image and obtain the equivalent translated image.\n", - "\n", - " **This particular notebook enables image-to-image translation learned from paired dataset. 
If you are interested in performing unpaired image-to-image translation, you should consider using the CycleGAN notebook instead.**\n", - "\n", - "---\n", - "\n", - "*Disclaimer*:\n", - "\n", - "This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n", - "\n", - "This notebook is based on the following paper: \n", - "\n", - " **Image-to-Image Translation with Conditional Adversarial Networks** by Isola *et al.* on arXiv in 2016 (https://arxiv.org/abs/1611.07004)\n", - "\n", - "The source code of the PyTorch implementation of pix2pix can be found here: /~https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n", - "\n", - "**Please also cite this original paper when using or developing this notebook.**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "W7HfryEazzJE" - }, - "source": [ - "# **License**\n", - "\n", - "---" - ] - }, - { - "cell_type": "code", - "metadata": { - "cellView": "form", - "id": "4TTFT14b0J6n" - }, - "source": [ - "#@markdown ##Double click to see the license information\n", - "\n", - "#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n", - "#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n", - "\n", - "\n", - "\n", - "#------------------------- LICENSE FOR CycleGAN ------------------------------------\n", - "\n", - "#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n", - "#All rights reserved.\n", - "\n", - "#Redistribution and use in source and binary forms, with or without\n", - "#modification, are permitted provided that the following conditions are met:\n", - "\n", - "#* Redistributions of source code must retain the above copyright notice, this\n", - "# list of conditions and the following disclaimer.\n", - "\n", - "#* Redistributions in binary form must reproduce the above copyright notice,\n", - "# this list of conditions and the following disclaimer in the documentation\n", - "# and/or other materials provided with the distribution.\n", - "\n", - "#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n", - "#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n", - "#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n", - "#DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n", - "#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n", - "#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n", - "#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n", - "#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n", - "#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n", - "#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", - "\n", - "\n", - "#--------------------------- LICENSE FOR pix2pix --------------------------------\n", - "#BSD License\n", - "\n", - "#For pix2pix software\n", - "#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n", - "#All rights reserved.\n", - "\n", - "#Redistribution and use in source and binary forms, with or without\n", - "#modification, are permitted provided that the following conditions are met:\n", - "\n", - "#* Redistributions of source code must retain the above copyright notice, this\n", - "# list of conditions and the following disclaimer.\n", - "\n", - "#* Redistributions in binary form must reproduce the above copyright notice,\n", - "# this list of conditions and the following disclaimer in the documentation\n", - "# and/or other materials provided with the distribution.\n", - "\n", - "#----------------------------- LICENSE FOR DCGAN --------------------------------\n", - "#BSD License\n", - "\n", - "#For dcgan.torch software\n", - "\n", - "#Copyright (c) 2015, Facebook, Inc. All rights reserved.\n", - "\n", - "#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n", - "\n", - "#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n", - "\n", - "#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n", - "\n", - "#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n", - "\n", - "#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE." 
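For reference, the objective that pix2pix optimises, as defined in the Isola et al. (2016) paper cited above, combines a conditional GAN loss with an L1 reconstruction term that keeps the translated image close to the paired target (the paper uses a weighting of lambda = 100 in its experiments):

$$\mathcal{L}_{cGAN}(G,D) = \mathbb{E}_{x,y}[\log D(x,y)] + \mathbb{E}_{x,z}[\log(1 - D(x, G(x,z)))]$$

$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}[\lVert y - G(x,z) \rVert_1]$$

$$G^{*} = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G,D) + \lambda\,\mathcal{L}_{L1}(G)$$

Here x is the input (source) image, y the paired target image, G the generator and D the discriminator; this is why the notebook requires paired Training_source/Training_target images with matching names.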
- ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jWAz2i7RdxUV" - }, - "source": [ - "# **How to use this notebook?**\n", - "\n", - "---\n", - "\n", - "Video describing how to use our notebooks are available on youtube:\n", - " - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n", - " - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n", - "\n", - "\n", - "---\n", - "###**Structure of a notebook**\n", - "\n", - "The notebook contains two types of cell: \n", - "\n", - "**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n", - "\n", - "**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n", - "\n", - "---\n", - "###**Table of contents, Code snippets** and **Files**\n", - "\n", - "On the top left side of the notebook you find three tabs which contain from top to bottom:\n", - "\n", - "*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n", - "\n", - "*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n", - "\n", - "*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n", - "\n", - "**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n", - "\n", - "**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n", - "\n", - "---\n", - "###**Making changes to the notebook**\n", - "\n", - "**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n", - "\n", - "To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n", - "You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gKDLkLWUd-YX" - }, - "source": [ - "#**0. Before getting started**\n", - "---\n", - " For pix2pix to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions and provided with indication of correspondence.\n", - "\n", - " Therefore, the data structure is important. It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called Training_source and Training_target. 
Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n", - "\n", - "**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n", - "\n", - " **Additionally, the corresponding input and output files need to have the same name**.\n", - "\n", - " Please note that you currently can **only use .PNG files!**\n", - "\n", - "\n", - "Here's a common data structure that can work:\n", - "* Experiment A\n", - " - **Training dataset**\n", - " - Training_source\n", - " - img_1.png, img_2.png, ...\n", - " - Training_target\n", - " - img_1.png, img_2.png, ...\n", - " - **Quality control dataset**\n", - " - Training_source\n", - " - img_1.png, img_2.png\n", - " - Training_target\n", - " - img_1.png, img_2.png\n", - " - **Data to be predicted**\n", - " - **Results**\n", - "\n", - "---\n", - "**Important note**\n", - "\n", - "- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n", - "\n", - "- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n", - "\n", - "- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n4yWFoJNnoin" - }, - "source": [ - "# **1. Initialise the Colab session**\n", - "\n", - "\n", - "\n", - "\n", - "---\n", - "\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DMNHVZfHmbKb" - }, - "source": [ - "\n", - "## **1.1. Check for GPU access**\n", - "---\n", - "\n", - "By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n", - "\n", - "Go to **Runtime -> Change the Runtime type**\n", - "\n", - "**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n", - "\n", - "**Accelerator: GPU** *(Graphics processing unit)*\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zCvebubeSaGY", - "cellView": "form" - }, - "source": [ - "#@markdown ##Run this cell to check if you have GPU access\n", - "\n", - "\n", - "import tensorflow as tf\n", - "if tf.test.gpu_device_name()=='':\n", - " print('You do not have GPU access.') \n", - " print('Did you change your runtime ?') \n", - " print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n", - " print('Expect slow performance. To access GPU try reconnecting later')\n", - "\n", - "else:\n", - " print('You have GPU access')\n", - " !nvidia-smi" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sNIVx8_CLolt" - }, - "source": [ - "## **1.2. 
Mount your Google Drive**\n", - "---\n", - " To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n", - "\n", - " Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n", - "\n", - " Once this is done, your data are available in the **Files** tab on the top left of notebook." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "01Djr8v-5pPk", - "cellView": "form" - }, - "source": [ - "#@markdown ##Play the cell to connect your Google Drive to Colab\n", - "\n", - "#@markdown * Click on the URL. \n", - "\n", - "#@markdown * Sign in your Google Account. \n", - "\n", - "#@markdown * Copy the authorization code. \n", - "\n", - "#@markdown * Enter the authorization code. \n", - "\n", - "#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n", - "\n", - "# mount user's Google Drive to Google Colab.\n", - "from google.colab import drive\n", - "drive.mount('/content/gdrive')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AdN8B91xZO0x" - }, - "source": [ - "# **2. Install pix2pix and dependencies**\n", - "---\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "fq21zJVFNASx", - "cellView": "form" - }, - "source": [ - "Notebook_version = ['1.12.4']\n", - "\n", - "from builtins import any as b_any\n", - "\n", - "def get_requirements_path():\n", - " # Store requirements file in 'contents' directory \n", - " current_dir = os.getcwd()\n", - " dir_count = current_dir.count('/') - 1\n", - " path = '../' * (dir_count) + 'requirements.txt'\n", - " return path\n", - "\n", - "def filter_files(file_list, filter_list):\n", - " filtered_list = []\n", - " for fname in file_list:\n", - " if b_any(fname.split('==')[0] in s for s in filter_list):\n", - " filtered_list.append(fname)\n", - " return filtered_list\n", - "\n", - "def build_requirements_file(before, after):\n", - " path = get_requirements_path()\n", - "\n", - " # Exporting requirements.txt for local run\n", - " !pip freeze > $path\n", - "\n", - " # Get minimum requirements file\n", - " df = pd.read_csv(path, delimiter = \"\\n\")\n", - " mod_list = [m.split('.')[0] for m in after if not m in before]\n", - " req_list_temp = df.values.tolist()\n", - " req_list = [x[0] for x in req_list_temp]\n", - "\n", - " # Replace with package name and handle cases where import name is different to module name\n", - " mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n", - " mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n", - " filtered_list = filter_files(req_list, mod_replace_list)\n", - "\n", - " file=open(path,'w')\n", - " for item in filtered_list:\n", - " file.writelines(item + '\\n')\n", - "\n", - " file.close()\n", - "\n", - "import sys\n", - "before = [str(m) for m in sys.modules]\n", - "\n", - "#@markdown ##Install pix2pix and dependencies\n", - "\n", - "#Here, we install libraries which are not already included in Colab.\n", - "!git clone /~https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n", - "import os\n", - "os.chdir('pytorch-CycleGAN-and-pix2pix/')\n", - "!pip install -r requirements.txt\n", - "!pip install fpdf\n", - "!pip install lpips\n", - 
"\n", - "import lpips\n", - "from PIL import Image\n", - "import imageio\n", - "from skimage import data\n", - "from skimage import exposure\n", - "from skimage.exposure import match_histograms\n", - "import os.path\n", - "\n", - "# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - "import urllib\n", - "import os, random\n", - "import shutil \n", - "import zipfile\n", - "from tifffile import imread, imsave\n", - "import time\n", - "import sys\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import csv\n", - "from glob import glob\n", - "from scipy import signal\n", - "from scipy import ndimage\n", - "from skimage import io\n", - "from sklearn.linear_model import LinearRegression\n", - "from skimage.util import img_as_uint\n", - "import matplotlib as mpl\n", - "from skimage.metrics import structural_similarity\n", - "from skimage.metrics import peak_signal_noise_ratio as psnr\n", - "from astropy.visualization import simple_norm\n", - "from skimage import img_as_float32\n", - "from skimage.util import img_as_ubyte\n", - "from tqdm import tqdm \n", - "from fpdf import FPDF, HTMLMixin\n", - "from datetime import datetime\n", - "import subprocess\n", - "from pip._internal.operations.freeze import freeze\n", - "\n", - "# Colors for the warning messages\n", - "class bcolors:\n", - " WARNING = '\\033[31m'\n", - "\n", - "#Disable some of the tensorflow warnings\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")\n", - "\n", - "print('----------------------------')\n", - "print(\"Libraries installed\")\n", - "\n", - "# Check if this is the latest version of the notebook\n", - "Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n", - "\n", - "if Notebook_version == list(Latest_notebook_version.columns):\n", - " print(\"This notebook is up-to-date.\")\n", - "\n", - "if not Notebook_version == list(Latest_notebook_version.columns):\n", - " print(bcolors.WARNING +\"A new version of this notebook has been released. 
We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n", - "\n", - "# average function\n", - "def Average(lst): \n", - " return sum(lst) / len(lst) \n", - "\n", - "def perceptual_diff(im0, im1, network, spatial):\n", - "\n", - " tensor0 = lpips.im2tensor(im0)\n", - " tensor1 = lpips.im2tensor(im1)\n", - " # Set up the loss function we will use\n", - " loss_fn = lpips.LPIPS(net=network, spatial=spatial, verbose=False)\n", - "\n", - " diff = loss_fn.forward(tensor0, tensor1)\n", - "\n", - " return diff\n", - "\n", - "def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'pix2pix'\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - " \n", - " # add another cell \n", - " if trained:\n", - " training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n", - " pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n", - " pdf.ln(1)\n", - "\n", - " Header_2 = 'Information for your materials and method:'\n", - " pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - " #print(all_packages)\n", - "\n", - " #Main Packages\n", - " main_packages = ''\n", - " version_numbers = []\n", - " for name in ['tensorflow','numpy','torch']:\n", - " find_name=all_packages.find(name)\n", - " main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n", - " #Version numbers only here:\n", - " version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n", - "\n", - " cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n", - " cuda_version = cuda_version.stdout.decode('utf-8')\n", - " cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n", - " gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n", - " gpu_name = gpu_name.stdout.decode('utf-8')\n", - " gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n", - " #print(cuda_version[cuda_version.find(', V')+3:-1])\n", - " #print(gpu_name)\n", - "\n", - " shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n", - " dataset_size = len(os.listdir(Training_source))\n", - "\n", - " text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a vanilla GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). 
The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " if Use_pretrained_model:\n", - " text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a vanilla GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " pdf.multi_cell(190, 5, txt = text, align='L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n", - " pdf.set_font('')\n", - " if augmentation:\n", - " aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n", - " if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n", - " aug_text = aug_text+'\\n- rotation'\n", - " if flip_left_right != 0 or flip_top_bottom != 0:\n", - " aug_text = aug_text+'\\n- flipping'\n", - " if random_zoom_magnification != 0:\n", - " aug_text = aug_text+'\\n- random zoom magnification'\n", - " if random_distortion != 0:\n", - " aug_text = aug_text+'\\n- random distortion'\n", - " if image_shear != 0:\n", - " aug_text = aug_text+'\\n- image shearing'\n", - " if skew_image != 0:\n", - " aug_text = aug_text+'\\n- image skewing'\n", - " else:\n", - " aug_text = 'No augmentation was used for training.'\n", - " pdf.multi_cell(190, 5, txt=aug_text, align='L')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " if Use_Default_Advanced_Parameters:\n", - " pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n", - " pdf.cell(200, 5, txt='The following parameters were used for training:')\n", - " pdf.ln(1)\n", - " html = \"\"\" \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
<tr><th align="left">Parameter</th><th align="left">Value</th></tr>
<tr><td>number_of_epochs</td><td>{0}</td></tr>
<tr><td>patch_size</td><td>{1}</td></tr>
<tr><td>batch_size</td><td>{2}</td></tr>
<tr><td>initial_learning_rate</td><td>{3}</td></tr>
\n", - " \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,initial_learning_rate)\n", - " pdf.write_html(html)\n", - "\n", - " #pdf.multi_cell(190, 5, txt = text_2, align='L')\n", - " pdf.set_font(\"Arial\", size = 11, style='B')\n", - " pdf.ln(1)\n", - " pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n", - " #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n", - " pdf.set_font('')\n", - " pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n", - " pdf.ln(1)\n", - " pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread('/content/TrainingDataExample_pix2pix.png').shape\n", - " pdf.image('/content/TrainingDataExample_pix2pix.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- pix2pix: Isola, Phillip, et al. \"Image-to-image translation with conditional adversarial networks.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - " if augmentation:\n", - " ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n", - " pdf.multi_cell(190, 5, txt = ref_3, align='L')\n", - " pdf.ln(3)\n", - " reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n", - "\n", - "def qc_pdf_export():\n", - " class MyFPDF(FPDF, HTMLMixin):\n", - " pass\n", - "\n", - " pdf = MyFPDF()\n", - " pdf.add_page()\n", - " pdf.set_right_margin(-1)\n", - " pdf.set_font(\"Arial\", size = 11, style='B') \n", - "\n", - " Network = 'pix2pix'\n", - "\n", - "\n", - " day = datetime.now()\n", - " datetime_str = str(day)[0:10]\n", - "\n", - " Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n", - " pdf.multi_cell(180, 5, txt = Header, align = 'L') \n", - "\n", - " all_packages = ''\n", - " for requirement in freeze(local_only=True):\n", - " all_packages = all_packages+requirement+', '\n", - "\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(2)\n", - " pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png').shape\n", - " pdf.image(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(2)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.ln(3)\n", - " pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n", - " pdf.ln(1)\n", - " exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n", - " if Image_type == 'RGB':\n", - " pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/5), h = round(exp_size[0]/5))\n", - " if Image_type == 'Grayscale':\n", - " pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 11, style = 'B')\n", - " pdf.ln(1)\n", - " pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - "\n", - " pdf.ln(1)\n", - " for checkpoint in os.listdir(full_QC_model_path+'/Quality Control'):\n", - " if os.path.isdir(os.path.join(full_QC_model_path,'Quality Control',checkpoint)) and checkpoint != 'Prediction':\n", - " pdf.set_font('')\n", - " pdf.set_font('Arial', size = 10, style = 'B')\n", - " pdf.cell(70, 5, txt = 'Metrics for checkpoint: '+ str(checkpoint), align='L', ln=1)\n", - " html = \"\"\"\n", - " \n", - " \n", - " \"\"\"\n", - " with open(full_QC_model_path+'/Quality Control/'+str(checkpoint)+'/QC_metrics_'+QC_model_name+str(checkpoint)+'.csv', 'r') as csvfile:\n", - " metrics = csv.reader(csvfile)\n", - " header = next(metrics)\n", - " image = header[0]\n", - " mSSIM_PvsGT = header[1]\n", - " mSSIM_SvsGT = header[2]\n", - " header = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT)\n", - " html = html+header\n", - " for row in metrics:\n", - " image = row[0]\n", - " mSSIM_PvsGT = 
row[1]\n", - " mSSIM_SvsGT = row[2]\n", - " cells = \"\"\"\n", - " \n", - " \n", - " \n", - " \n", - " \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)))\n", - " html = html+cells\n", - " html = html+\"\"\"
<tr><th align="left">{0}</th><th align="left">{1}</th><th align="left">{2}</th></tr>
<tr><td>{0}</td><td>{1}</td><td>{2}</td></tr>
\"\"\"\n", - " pdf.write_html(html)\n", - " pdf.ln(2)\n", - " else:\n", - " continue\n", - "\n", - " pdf.ln(1)\n", - " pdf.set_font('')\n", - " pdf.set_font_size(10.)\n", - " ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n", - " pdf.multi_cell(190, 5, txt = ref_1, align='L')\n", - " ref_2 = '- pix2pix: Isola, Phillip, et al. \"Image-to-image translation with conditional adversarial networks.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n", - " pdf.multi_cell(190, 5, txt = ref_2, align='L')\n", - "\n", - " pdf.ln(3)\n", - " reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n", - "\n", - " pdf.set_font('Arial', size = 11, style='B')\n", - " pdf.multi_cell(190, 5, txt=reminder, align='C')\n", - "\n", - " pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n", - "\n", - "# Build requirements file for local run\n", - "after = [str(m) for m in sys.modules]\n", - "build_requirements_file(before, after)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HLYcZR9gMv42" - }, - "source": [ - "# **3. Select your parameters and paths**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FQ_QxtSWQ7CL" - }, - "source": [ - "## **3.1. Setting main training parameters**\n", - "---\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AuESFimvMv43" - }, - "source": [ - " **Paths for training, predictions and results**\n", - "\n", - "**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n", - "\n", - "**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n", - "\n", - "**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n", - "\n", - "**Training parameters**\n", - "\n", - "**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n", - "\n", - "**Advanced Parameters - experienced users only**\n", - "\n", - "**`patch_size`:** pix2pix divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 512**\n", - "\n", - "**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n", - "\n", - "**`batch_size:`** This parameter defines the number of patches seen in each training step. 
Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n", - "\n", - "**`initial_learning_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0002**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "ewpNJ_I0Mv47", - "cellView": "form" - }, - "source": [ - "import cv2\n", - "\n", - "#@markdown ###Path to training images:\n", - "\n", - "Training_source = \"\" #@param {type:\"string\"}\n", - "#InputFile = Training_source+\"/*.png\"\n", - "\n", - "Training_target = \"\" #@param {type:\"string\"}\n", - "#OutputFile = Training_target+\"/*.png\"\n", - "\n", - "#@markdown ###Image normalisation:\n", - "\n", - "Normalisation_training_source = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n", - "Normalisation_training_target = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n", - "\n", - "\n", - "#Define where the patch file will be saved\n", - "base = \"/content\"\n", - "\n", - "# model name and path\n", - "#@markdown ###Name of the model and path to model folder:\n", - "model_name = \"\" #@param {type:\"string\"}\n", - "model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# other parameters for training.\n", - "#@markdown ###Training Parameters\n", - "#@markdown Number of epochs:\n", - "number_of_epochs = 200#@param {type:\"number\"}\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "\n", - "Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please input:\n", - "patch_size = 512#@param {type:\"number\"} # in pixels\n", - "batch_size = 1#@param {type:\"number\"}\n", - "initial_learning_rate = 0.0002 #@param {type:\"number\"}\n", - "\n", - "\n", - "if (Use_Default_Advanced_Parameters): \n", - " print(\"Default advanced parameters enabled\")\n", - " batch_size = 1\n", - " patch_size = 512\n", - " initial_learning_rate = 0.0002\n", - "\n", - "#here we check that no model with the same name already exist, if so delete\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " print(bcolors.WARNING +\"!! 
WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n", - " print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n", - " \n", - "#To use pix2pix we need to organise the data in a way the network can understand\n", - "\n", - "Saving_path= \"/content/\"+model_name\n", - "\n", - "if os.path.exists(Saving_path):\n", - " shutil.rmtree(Saving_path)\n", - "os.makedirs(Saving_path)\n", - "\n", - "imageA_folder = Saving_path+\"/A\"\n", - "os.makedirs(imageA_folder)\n", - "\n", - "imageB_folder = Saving_path+\"/B\"\n", - "os.makedirs(imageB_folder)\n", - "\n", - "imageAB_folder = Saving_path+\"/AB\"\n", - "os.makedirs(imageAB_folder)\n", - "\n", - "TrainA_Folder = Saving_path+\"/A/train\"\n", - "os.makedirs(TrainA_Folder)\n", - " \n", - "TrainB_Folder = Saving_path+\"/B/train\"\n", - "os.makedirs(TrainB_Folder)\n", - "\n", - "# Here we disable pre-trained model by default (in case the cell is not ran)\n", - "Use_pretrained_model = False\n", - "\n", - "# Here we disable data augmentation by default (in case the cell is not ran)\n", - "\n", - "Use_Data_augmentation = False\n", - "\n", - "# Here we normalise the image is enabled\n", - "\n", - "if Normalisation_training_source == \"Contrast stretching\":\n", - "\n", - " Training_source_norm = Saving_path+\"/Training_source_norm\"\n", - " os.makedirs(Training_source_norm)\n", - " \n", - " for filename in os.listdir(Training_source):\n", - "\n", - " img = imread(os.path.join(Training_source,filename)).astype(np.float32)\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " p2, p99 = np.percentile(img, (2, 99.9))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(Training_source_norm+\"/\"+short_name[0]+\".png\", img)\n", - " \n", - " Training_source = Training_source_norm\n", - "\n", - "\n", - "if Normalisation_training_target == \"Contrast stretching\":\n", - "\n", - " Training_target_norm = Saving_path+\"/Training_target_norm\"\n", - " os.makedirs(Training_target_norm)\n", - "\n", - " for filename in os.listdir(Training_target):\n", - "\n", - " img = imread(os.path.join(Training_target,filename)).astype(np.float32)\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " p2, p99 = np.percentile(img, (2, 99.9))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(Training_target_norm+\"/\"+short_name[0]+\".png\", img)\n", - "\n", - " Training_target = Training_target_norm\n", - "\n", - "\n", - "if Normalisation_training_source == \"Adaptive Equalization\":\n", - " Training_source_norm = Saving_path+\"/Training_source_norm\"\n", - " os.makedirs(Training_source_norm)\n", - "\n", - " for filename in os.listdir(Training_source):\n", - "\n", - " img = imread(os.path.join(Training_source,filename))\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " img = exposure.equalize_adapthist(img, clip_limit=0.03)\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(Training_source_norm+\"/\"+short_name[0]+\".png\", img)\n", - "\n", - "\n", - " Training_source = Training_source_norm\n", - "\n", - "\n", - "if Normalisation_training_target == \"Adaptive Equalization\":\n", - "\n", - " Training_target_norm = 
Saving_path+\"/Training_target_norm\"\n", - " os.makedirs(Training_target_norm)\n", - "\n", - " for filename in os.listdir(Training_target):\n", - "\n", - " img = imread(os.path.join(Training_target,filename))\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " p2, p99 = np.percentile(img, (2, 99.8))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(Training_target_norm+\"/\"+short_name[0]+\".png\", img)\n", - "\n", - " Training_target = Training_target_norm\n", - "\n", - "# This will display a randomly chosen dataset input and output\n", - "random_choice = random.choice(os.listdir(Training_source))\n", - "x = io.imread(Training_source+\"/\"+random_choice)\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "Image_min_dim = min(Image_Y, Image_X)\n", - "\n", - "#Hyperparameters failsafes\n", - "if patch_size > min(Image_Y, Image_X):\n", - " patch_size = min(Image_Y, Image_X)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "# Here we check that patch_size is divisible by 4\n", - "if not patch_size % 4 == 0:\n", - " patch_size = ((int(patch_size / 4)-1) * 4)\n", - " print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "# Here we check that patch_size is at least bigger than 256\n", - "if patch_size < 256:\n", - " patch_size = 256\n", - " print (bcolors.WARNING + \" Your chosen patch_size is too small; therefore the patch_size chosen is now:\",patch_size)\n", - "\n", - "y = io.imread(Training_target+\"/\"+random_choice)\n", - "\n", - "n_channel_x = 1 if x.ndim == 2 else x.shape[-1]\n", - "n_channel_y = 1 if y.ndim == 2 else y.shape[-1]\n", - "\n", - "if n_channel_x == 1:\n", - " cmap_x = 'gray'\n", - "else:\n", - " cmap_x = None\n", - "\n", - "if n_channel_y == 1:\n", - " cmap_y = 'gray'\n", - "else:\n", - " cmap_y = None\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x, cmap=cmap_x, interpolation='nearest')\n", - "plt.title('Training source')\n", - "plt.axis('off');\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y, cmap=cmap_y, interpolation='nearest')\n", - "plt.title('Training target')\n", - "plt.axis('off');\n", - "\n", - "plt.savefig('/content/TrainingDataExample_pix2pix.png',bbox_inches='tight',pad_inches=0)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xyQZKby8yFME" - }, - "source": [ - "## **3.2. Data augmentation**\n", - "---\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_jCy7xOx2g3" - }, - "source": [ - "Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. 
Augmentation is not necessary for training and if your training dataset is large you should disable it.\n", - "\n", - "Data augmentation is performed here by [Augmentor.](/~https://github.com/mdbloice/Augmentor)\n", - "\n", - "[Augmentor](/~https://github.com/mdbloice/Augmentor) was described in the following article:\n", - "\n", - "Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n", - "\n", - "**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "DMqWq5-AxnFU", - "cellView": "form" - }, - "source": [ - "#Data augmentation\n", - "\n", - "Use_Data_augmentation = False #@param {type:\"boolean\"}\n", - "\n", - "if Use_Data_augmentation:\n", - " !pip install Augmentor\n", - " import Augmentor\n", - "\n", - "\n", - "#@markdown ####Choose a factor by which you want to multiply your original dataset\n", - "\n", - "Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, step:1}\n", - "\n", - "Save_augmented_images = False #@param {type:\"boolean\"}\n", - "\n", - "Saving_path = \"\" #@param {type:\"string\"}\n", - "\n", - "\n", - "Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n", - "#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n", - "\n", - "#@markdown ####Mirror and rotate images\n", - "rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "#@markdown ####Random image Zoom\n", - "\n", - "random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "#@markdown ####Random image distortion\n", - "\n", - "random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "\n", - "#@markdown ####Image shearing and skewing \n", - "\n", - "image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "max_image_shear = 10 #@param {type:\"slider\", min:1, max:25, step:1}\n", - "\n", - "skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n", - "\n", - "\n", - "if Use_Default_Augmentation_Parameters:\n", - " rotate_90_degrees = 0.5\n", - " rotate_270_degrees = 0.5\n", - " flip_left_right = 0.5\n", - " flip_top_bottom = 0.5\n", - "\n", - " if not Multiply_dataset_by >5:\n", - " random_zoom = 0\n", - " random_zoom_magnification = 0.9\n", - " random_distortion = 0\n", - " image_shear = 0\n", - " max_image_shear = 10\n", - " skew_image = 0\n", - " skew_image_magnitude = 0\n", - "\n", - " if Multiply_dataset_by >5:\n", - " random_zoom = 0.1\n", - " random_zoom_magnification = 0.9\n", - " random_distortion = 0.5\n", - " image_shear = 0.2\n", - " max_image_shear = 5\n", - " skew_image = 0.2\n", - " skew_image_magnitude = 0.4\n", - "\n", - " if Multiply_dataset_by >25:\n", - " random_zoom = 0.5\n", - " random_zoom_magnification = 0.8\n", - " random_distortion = 0.5\n", - " image_shear = 0.5\n", - " 
max_image_shear = 20\n", - " skew_image = 0.5\n", - " skew_image_magnitude = 0.6\n", - "\n", - "\n", - "list_files = os.listdir(Training_source)\n", - "Nb_files = len(list_files)\n", - "\n", - "Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n", - "\n", - "\n", - "if Use_Data_augmentation:\n", - " print(\"Data augmentation enabled\")\n", - "# Here we set the path for the various folder were the augmented images will be loaded\n", - "\n", - "# All images are first saved into the augmented folder\n", - " #Augmented_folder = \"/content/Augmented_Folder\"\n", - " \n", - " if not Save_augmented_images:\n", - " Saving_path= \"/content\"\n", - "\n", - " Augmented_folder = Saving_path+\"/Augmented_Folder\"\n", - " if os.path.exists(Augmented_folder):\n", - " shutil.rmtree(Augmented_folder)\n", - " os.makedirs(Augmented_folder)\n", - "\n", - " #Training_source_augmented = \"/content/Training_source_augmented\"\n", - " Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n", - "\n", - " if os.path.exists(Training_source_augmented):\n", - " shutil.rmtree(Training_source_augmented)\n", - " os.makedirs(Training_source_augmented)\n", - "\n", - " #Training_target_augmented = \"/content/Training_target_augmented\"\n", - " Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n", - "\n", - " if os.path.exists(Training_target_augmented):\n", - " shutil.rmtree(Training_target_augmented)\n", - " os.makedirs(Training_target_augmented)\n", - "\n", - "\n", - "# Here we generate the augmented images\n", - "#Load the images\n", - " p = Augmentor.Pipeline(Training_source, Augmented_folder)\n", - "\n", - "#Define the matching images\n", - " p.ground_truth(Training_target)\n", - "#Define the augmentation possibilities\n", - " if not rotate_90_degrees == 0:\n", - " p.rotate90(probability=rotate_90_degrees)\n", - " \n", - " if not rotate_270_degrees == 0:\n", - " p.rotate270(probability=rotate_270_degrees)\n", - "\n", - " if not flip_left_right == 0:\n", - " p.flip_left_right(probability=flip_left_right)\n", - "\n", - " if not flip_top_bottom == 0:\n", - " p.flip_top_bottom(probability=flip_top_bottom)\n", - "\n", - " if not random_zoom == 0:\n", - " p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n", - " \n", - " if not random_distortion == 0:\n", - " p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n", - "\n", - " if not image_shear == 0:\n", - " p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n", - " \n", - " if not skew_image == 0:\n", - " p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n", - "\n", - " p.sample(int(Nb_augmented_files))\n", - "\n", - " print(int(Nb_augmented_files),\"matching images generated\")\n", - "\n", - "# Here we sort through the images and move them back to augmented trainning source and targets folders\n", - "\n", - " augmented_files = os.listdir(Augmented_folder)\n", - "\n", - " for f in augmented_files:\n", - "\n", - " if (f.startswith(\"_groundtruth_(1)_\")):\n", - " shortname_noprefix = f[17:]\n", - " shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n", - " if not (f.startswith(\"_groundtruth_(1)_\")):\n", - " shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n", - " \n", - "\n", - " for filename in os.listdir(Training_source_augmented):\n", - " os.chdir(Training_source_augmented)\n", - " os.rename(filename, filename.replace('_original', ''))\n", - " \n", - " 
#Here we clean up the extra files\n", - " shutil.rmtree(Augmented_folder)\n", - "\n", - "if not Use_Data_augmentation:\n", - " print(bcolors.WARNING+\"Data augmentation disabled\") \n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3L9zSGtORKYI" - }, - "source": [ - "\n", - "## **3.3. Using weights from a pre-trained model as initial weights**\n", - "---\n", - " Here, you can set the the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a pix2pix model**. \n", - "\n", - " This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "9vC2n-HeLdiJ", - "cellView": "form" - }, - "source": [ - "# @markdown ##Loading weights from a pre-trained network\n", - "\n", - "\n", - "Use_pretrained_model = False #@param {type:\"boolean\"}\n", - "\n", - "\n", - "#@markdown ###If yes, please provide the path to the model folder:\n", - "pretrained_model_path = \"\" #@param {type:\"string\"}\n", - "\n", - "# --------------------- Check if we load a previously trained model ------------------------\n", - "if Use_pretrained_model:\n", - "\n", - " h5_file_path = os.path.join(pretrained_model_path, \"latest_net_G.pth\")\n", - " \n", - "\n", - "# --------------------- Check the model exist ------------------------\n", - "\n", - " if not os.path.exists(h5_file_path):\n", - " print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n", - " Use_pretrained_model = False\n", - " print(bcolors.WARNING+'No pretrained network will be used.')\n", - "\n", - " if os.path.exists(h5_file_path):\n", - " print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n", - " \n", - "else:\n", - " print(bcolors.WARNING+'No pretrained network will be used.')\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MCGklf1vZf2M" - }, - "source": [ - "#**4. Train the network**\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1KYOuygETJkT" - }, - "source": [ - "## **4.1. Prepare the training data for training**\n", - "---\n", - "Here, we use the information from Section 3 to prepare the training data into a suitable format for training. 
**Your data will be copied in the google Colab \"content\" folder which may take some time depending on the size of your dataset.**\n", - "\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "lIUAOJ_LMv5E", - "cellView": "form" - }, - "source": [ - "#@markdown ##Prepare the data for training\n", - "\n", - "\n", - "# --------------------- Here we load the augmented data or the raw data ------------------------\n", - "\n", - "if Use_Data_augmentation:\n", - " Training_source_dir = Training_source_augmented\n", - " Training_target_dir = Training_target_augmented\n", - "\n", - "if not Use_Data_augmentation:\n", - " Training_source_dir = Training_source\n", - " Training_target_dir = Training_target\n", - "# --------------------- ------------------------------------------------\n", - "\n", - "print(\"Data preparation in progress\")\n", - "\n", - "if os.path.exists(model_path+'/'+model_name):\n", - " shutil.rmtree(model_path+'/'+model_name)\n", - "os.makedirs(model_path+'/'+model_name)\n", - "\n", - "#--------------- Here we move the files to trainA and train B ---------\n", - "\n", - "\n", - "print('Copying training source data...')\n", - "for f in tqdm(os.listdir(Training_source_dir)):\n", - " shutil.copyfile(Training_source_dir+\"/\"+f, TrainA_Folder+\"/\"+f)\n", - "\n", - "print('Copying training target data...')\n", - "for f in tqdm(os.listdir(Training_target_dir)):\n", - " shutil.copyfile(Training_target_dir+\"/\"+f, TrainB_Folder+\"/\"+f)\n", - "\n", - "#---------------------------------------------------------------------\n", - "\n", - "#--------------- Here we combined A and B images---------\n", - "os.chdir(\"/content\")\n", - "!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n", - "\n", - "\n", - "\n", - "# pix2pix uses EPOCH without lr decay and EPOCH with lr decay, here we automatically choose half and half\n", - "\n", - "number_of_epochs_lr_stable = int(number_of_epochs/2)\n", - "number_of_epochs_lr_decay = int(number_of_epochs/2)\n", - "\n", - "if Use_pretrained_model :\n", - " for f in os.listdir(pretrained_model_path):\n", - " if (f.startswith(\"latest_net_\")): \n", - " shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n", - "\n", - "#Export of pdf summary of training parameters\n", - "pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n", - "\n", - "print('------------------------')\n", - "print(\"Data ready for training\")\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0Dfn8ZsEMv5d" - }, - "source": [ - "## **4.2. Start Training**\n", - "---\n", - "When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n", - "\n", - "* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session. **Pix2pix will save model checkpoints every 5 epochs.**\n", - "\n", - "Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. It is however wise to download the folder as all data can be erased at the next training if using the same folder." 
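Because of the Colab time limit mentioned above, it can help to see how an interrupted run could be continued in a second session. The snippet below is an illustrative sketch only, not a cell of the notebook: it assumes the variables defined in Sections 3 and 4 (imageAB_folder, model_name, model_path, number_of_epochs_lr_stable, number_of_epochs_lr_decay) are still in memory and that the checkpoints written by the previous run (e.g. latest_net_G.pth) are present under model_path/model_name. It relies on the --continue_train and --epoch_count options listed among the pix2pix command-line parameters in the training cell that follows.

# Illustrative sketch only: resume an interrupted pix2pix training run from the latest
# saved checkpoint. Assumes Section 3/4 variables are still defined in this session and
# that latest_net_G.pth / latest_net_D.pth exist in model_path/model_name.
import os

os.chdir("/content")
resume_cmd = (
    "python pytorch-CycleGAN-and-pix2pix/train.py"
    f" --dataroot {imageAB_folder}"                 # combined A|B patches prepared in Section 4.1
    f" --name {model_name}"
    f" --checkpoints_dir {model_path}"              # assumed: same folder the first run saved to
    " --model pix2pix"
    " --continue_train"                             # reload the latest saved generator/discriminator
    f" --epoch_count {number_of_epochs_lr_stable}"  # assumed restart point; adjust to the epoch actually reached
    f" --n_epochs {number_of_epochs_lr_stable}"
    f" --n_epochs_decay {number_of_epochs_lr_decay}"
)
os.system(resume_cmd)

The same pattern can also be used to extend a finished training run with additional epochs, by increasing --n_epochs / --n_epochs_decay while keeping --continue_train.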
- ] - }, - { - "cell_type": "code", - "metadata": { - "scrolled": true, - "id": "iwNmp1PUzRDQ", - "cellView": "form" - }, - "source": [ - "#@markdown ##Start training\n", - "\n", - "start = time.time()\n", - "\n", - "os.chdir(\"/content\")\n", - "\n", - "#--------------------------------- Command line inputs to change pix2pix paramaters------------\n", - "\n", - " # basic parameters\n", - " #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n", - " #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n", - " #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n", - " #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n", - " \n", - " # model parameters\n", - " #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n", - " #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n", - " #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n", - " #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n", - " #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n", - " #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n", - " #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n", - " #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n", - " #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n", - " #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n", - " #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n", - " #('--no_dropout', action='store_true', help='no dropout for the generator')\n", - " \n", - " # dataset parameters\n", - " #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n", - " #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n", - " #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n", - " #('--num_threads', default=4, type=int, help='# threads for loading data')\n", - " #('--batch_size', type=int, default=1, help='input batch size')\n", - " #('--load_size', type=int, default=286, help='scale images to this size')\n", - " #('--crop_size', type=int, default=256, help='then crop to this size')\n", - " #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n", - " #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n", - " #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n", - " #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n", - " \n", - " # additional parameters\n", - " #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n", - " #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n", - " #('--verbose', action='store_true', help='if specified, print more debugging information')\n", - " #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n", - " \n", - " # visdom and HTML visualization parameters\n", - " #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n", - " #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n", - " #('--display_id', type=int, default=1, help='window id of the web display')\n", - " #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n", - " #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n", - " #('--display_port', type=int, default=8097, help='visdom port of the web display')\n", - " #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n", - " #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n", - " #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n", - " \n", - " # network saving and loading parameters\n", - " #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n", - " #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n", - " #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n", - " #('--continue_train', action='store_true', help='continue training: load the latest model')\n", - " #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n", - " #('--phase', type=str, default='train', help='train, val, test, etc')\n", - " \n", - " # training parameters\n", - " #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n", - " #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n", - " #('--beta1', type=float, default=0.5, help='momentum term of adam')\n", - " #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n", - " #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. 
vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n", - " #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n", - " #('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')\n", - " #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n", - "\n", - "#---------------------------------------------------------\n", - "\n", - "#----- Start the training ------------------------------------\n", - "if not Use_pretrained_model:\n", - " !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n", - "\n", - "if Use_pretrained_model:\n", - " !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n", - "\n", - "\n", - "#---------------------------------------------------------\n", - "\n", - "print(\"Training, done.\")\n", - "\n", - "# Displaying the time elapsed for training\n", - "dt = time.time() - start\n", - "mins, sec = divmod(dt, 60) \n", - "hour, mins = divmod(mins, 60) \n", - "print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n", - "\n", - "# Export pdf summary after training to update document\n", - "pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_0Hynw3-xHp1" - }, - "source": [ - "# **5. Evaluate your model**\n", - "---\n", - "\n", - "This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n", - "\n", - "**We highly recommend to perform quality control on all newly trained models.**\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HQqBkYzT4hQS" - }, - "source": [ - "## **5.1. 
Choose the model you want to assess**" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "eAJzMwPA6tlH", - "cellView": "form" - }, - "source": [ - "# model name and path\n", - "#@markdown ###Do you want to assess the model you just trained?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "QC_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#Here we define the loaded model name and path\n", - "QC_model_name = os.path.basename(QC_model_folder)\n", - "QC_model_path = os.path.dirname(QC_model_folder)\n", - "\n", - "if (Use_the_current_trained_model): \n", - " QC_model_name = model_name\n", - " QC_model_path = model_path\n", - "\n", - "full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n", - "if os.path.exists(full_QC_model_path):\n", - " print(\"The \"+QC_model_name+\" network will be evaluated\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kittWWbs4pc8" - }, - "source": [ - "## **5.2. Identify the best checkpoint to use to make predictions**" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SeGNGf4A4ukf" - }, - "source": [ - " Pix2pix saves model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n", - "\n", - "This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truth images. Metrics used include:\n", - "\n", - "**1. The SSIM (structural similarity) map** \n", - "\n", - "The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalized metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric in each pixel by considering the surrounding structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with a Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n", - "\n", - "**mSSIM** is the SSIM value calculated across the entire window of both images.\n", - "\n", - "**The output below shows the SSIM maps with the mSSIM**\n", - "\n", - "**2. The RSE (Root Squared Error) map** \n", - "\n", - "This is a display of the root of the squared difference between the normalised prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n", - "\n", - "\n", - "**NRMSE (normalised root mean squared error)** gives the average difference between all pixels in the images compared to each other. Good agreement yields low NRMSE scores.\n", - "\n", - "**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
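For reference, the standard definition of this quantity is given below (a textbook formulation, not taken from this notebook's code; MAX denotes the peak pixel value of the normalised images and MSE the mean squared error between them):

$$\mathrm{PSNR} = 10 \cdot \log_{10}\!\left(\frac{\mathrm{MAX}^{2}}{\mathrm{MSE}}\right)$$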
The higher the score the better the agreement.\n", - "\n", - "**The output below shows the RSE maps with the NRMSE and PSNR values.**\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "VfF_oMpI4-Xl", - "cellView": "form" - }, - "source": [ - "#@markdown ##Choose the folders that contain your Quality Control dataset\n", - "\n", - "import glob\n", - "import os.path\n", - "from scipy import stats\n", - "\n", - "#@markdown ###Path to images:\n", - "\n", - "Source_QC_folder = \"\" #@param{type:\"string\"}\n", - "Target_QC_folder = \"\" #@param{type:\"string\"}\n", - "\n", - "#@markdown ###Type of images:\n", - "\n", - "Image_type = \"Grayscale\" #@param [\"Grayscale\", \"RGB\"]\n", - "\n", - "#@markdown ###Image normalisation:\n", - "\n", - "Normalisation_QC_source = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n", - "Normalisation_QC_target = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n", - "\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "\n", - "patch_size_QC = 512#@param {type:\"number\"} # in pixels\n", - "Do_lpips_analysis = False #@param {type:\"boolean\"}\n", - "\n", - "\n", - "\n", - "# Create a quality control folder\n", - "\n", - "if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n", - " shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n", - "\n", - "os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n", - "\n", - "\n", - "# Create a quality control/Prediction Folder\n", - "\n", - "QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"\n", - "\n", - "if os.path.exists(QC_prediction_results):\n", - " shutil.rmtree(QC_prediction_results)\n", - "\n", - "os.makedirs(QC_prediction_results)\n", - "\n", - "# Here we count how many images are in our folder to be predicted and we had a few\n", - "Nb_files_Data_folder = len(os.listdir(Source_QC_folder)) +10\n", - "\n", - "# List images in Source_QC_folder\n", - "# This will find the image dimension of a randomly choosen image in Source_QC_folder \n", - "random_choice = random.choice(os.listdir(Source_QC_folder))\n", - "x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "Image_min_dim = min(Image_Y, Image_X)\n", - "\n", - "# Here we need to move the data to be analysed so that pix2pix can find them\n", - "\n", - "Saving_path_QC= \"/content/\"+QC_model_name+\"_images\"\n", - "\n", - "if os.path.exists(Saving_path_QC):\n", - " shutil.rmtree(Saving_path_QC)\n", - "os.makedirs(Saving_path_QC)\n", - "\n", - "Saving_path_QC_folder = Saving_path_QC+\"/QC\"\n", - "\n", - "if os.path.exists(Saving_path_QC_folder):\n", - " shutil.rmtree(Saving_path_QC_folder)\n", - "os.makedirs(Saving_path_QC_folder)\n", - "\n", - "imageA_folder = Saving_path_QC_folder+\"/A\"\n", - "os.makedirs(imageA_folder)\n", - "\n", - "imageB_folder = Saving_path_QC_folder+\"/B\"\n", - "os.makedirs(imageB_folder)\n", - "\n", - "imageAB_folder = Saving_path_QC_folder+\"/AB\"\n", - "os.makedirs(imageAB_folder)\n", - "\n", - "testAB_folder = Saving_path_QC_folder+\"/AB/test\"\n", - "os.makedirs(testAB_folder)\n", - "\n", - "testA_Folder = Saving_path_QC_folder+\"/A/test\"\n", - "os.makedirs(testA_Folder)\n", - " \n", - "testB_Folder = Saving_path_QC_folder+\"/B/test\"\n", - "os.makedirs(testB_Folder)\n", - "\n", - "QC_checkpoint_folders = \"/content/\"+QC_model_name\n", - 
"\n", - "if os.path.exists(QC_checkpoint_folders):\n", - " shutil.rmtree(QC_checkpoint_folders)\n", - "os.makedirs(QC_checkpoint_folders)\n", - "\n", - "#Here we copy and normalise the data\n", - "\n", - "if Normalisation_QC_source == \"Contrast stretching\":\n", - " \n", - " for filename in os.listdir(Source_QC_folder):\n", - "\n", - " img = imread(os.path.join(Source_QC_folder,filename)).astype(np.float32)\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " p2, p99 = np.percentile(img, (2, 99.9))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n", - " \n", - "if Normalisation_QC_target == \"Contrast stretching\":\n", - "\n", - " for filename in os.listdir(Target_QC_folder):\n", - "\n", - " img = imread(os.path.join(Target_QC_folder,filename)).astype(np.float32)\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " p2, p99 = np.percentile(img, (2, 99.9))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n", - "\n", - "if Normalisation_QC_source == \"Adaptive Equalization\":\n", - " for filename in os.listdir(Source_QC_folder):\n", - "\n", - " img = imread(os.path.join(Source_QC_folder,filename))\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " img = exposure.equalize_adapthist(img, clip_limit=0.03)\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n", - "\n", - "\n", - "if Normalisation_QC_target == \"Adaptive Equalization\":\n", - "\n", - " for filename in os.listdir(Target_QC_folder):\n", - "\n", - " img = imread(os.path.join(Target_QC_folder,filename))\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " p2, p99 = np.percentile(img, (2, 99.8))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n", - "\n", - "if Normalisation_QC_source == \"None\":\n", - " for files in os.listdir(Source_QC_folder):\n", - " shutil.copyfile(Source_QC_folder+\"/\"+files, testA_Folder+\"/\"+files)\n", - "\n", - "if Normalisation_QC_target == \"None\":\n", - " for files in os.listdir(Target_QC_folder):\n", - " shutil.copyfile(Target_QC_folder+\"/\"+files, testB_Folder+\"/\"+files)\n", - "\n", - "\n", - "#Here we create a merged folder containing only imageA\n", - "os.chdir(\"/content\")\n", - "\n", - "!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n", - "\n", - "# This will find the image dimension of a randomly choosen image in Source_QC_folder \n", - "random_choice = random.choice(os.listdir(Source_QC_folder))\n", - "x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "Image_min_dim = int(min(Image_Y, Image_X))\n", - "\n", - "if not patch_size_QC % 256 == 0:\n", - " patch_size_QC = ((int(patch_size_QC / 256)) * 256)\n", - " print (\" Your image dimensions are not divisible by 256; therefore your images have now been resized 
to:\",patch_size_QC)\n", - "\n", - "if patch_size_QC < 256:\n", - " patch_size_QC = 256\n", - "\n", - "Nb_Checkpoint = len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))\n", - "\n", - "print(Nb_Checkpoint)\n", - "\n", - "## Initiate lists\n", - "\n", - "Checkpoint_list = []\n", - "Average_ssim_score_list = []\n", - "Average_lpips_score_list = []\n", - "\n", - "for j in range(1, len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))+1):\n", - " checkpoints = j*5\n", - "\n", - " if checkpoints == Nb_Checkpoint*5:\n", - " checkpoints = \"latest\"\n", - "\n", - " print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n", - "\n", - " Checkpoint_list.append(checkpoints)\n", - "\n", - " # Create a quality control/Prediction Folder\n", - "\n", - " QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n", - "\n", - " if os.path.exists(QC_prediction_results):\n", - " shutil.rmtree(QC_prediction_results)\n", - "\n", - " os.makedirs(QC_prediction_results)\n", - "\n", - "#---------------------------- Predictions are performed here ----------------------\n", - " os.chdir(\"/content\")\n", - " !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$QC_model_name\" --model pix2pix --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $patch_size_QC --crop_size $patch_size_QC --results_dir \"$QC_prediction_results\" --checkpoints_dir \"$QC_model_path\" --direction AtoB --num_test $Nb_files_Data_folder\n", - "#-----------------------------------------------------------------------------------\n", - "\n", - "#Here we need to move the data again and remove all the unnecessary folders\n", - "\n", - " Checkpoint_name = \"test_\"+str(checkpoints)\n", - "\n", - " QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n", - "\n", - " QC_results_images_files = os.listdir(QC_results_images)\n", - "\n", - " for f in QC_results_images_files: \n", - " shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n", - "\n", - " os.chdir(\"/content\") \n", - "\n", - " #Here we clean up the extra files\n", - " shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n", - "\n", - " #-------------------------------- QC for RGB ------------------------------------\n", - " if Image_type == \"RGB\":\n", - "# List images in Source_QC_folder\n", - "# This will find the image dimension of a randomly choosen image in Source_QC_folder \n", - " random_choice = random.choice(os.listdir(Source_QC_folder))\n", - " x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n", - "\n", - " def ssim(img1, img2):\n", - " return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - " with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\" , \"Prediction v. GT lpips\", \"Input v. 
GT lpips\"])\n", - " \n", - " \n", - " # Initiate list\n", - " ssim_score_list = []\n", - " lpips_score_list = [] \n", - "\n", - "\n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - "\n", - " for i in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", - " print('Running QC on: '+i)\n", - "\n", - " shortname_no_PNG = i[:-4]\n", - " \n", - " # -------------------------------- Target test data (Ground truth) --------------------------------\n", - " \n", - " test_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\n", - "\n", - " # -------------------------------- Source test data --------------------------------\n", - " test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\n", - " \n", - " \n", - " # -------------------------------- Prediction --------------------------------\n", - " \n", - " test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n", - " \n", - " #--------------------------- Here we normalise using histograms matching--------------------------------\n", - " test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n", - " test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n", - " \n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n", - "\n", - " ssim_score_list.append(index_SSIM_GTvsPrediction)\n", - "\n", - " #Save ssim_maps\n", - " img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n", - " img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n", - "\n", - " # -------------------------------- Pearson correlation coefficient --------------------------------\n", - "\n", - "\n", - "\n", - "\n", - "\n", - " # -------------------------------- Calculate the perceptual difference metrics map and save them --------------------------------\n", - " if Do_lpips_analysis:\n", - "\n", - " lpips_GTvsPrediction = perceptual_diff(test_GT, test_prediction, 'alex', True)\n", - " lpips_GTvsPrediction_image = lpips_GTvsPrediction[0,0,...].data.cpu().numpy()\n", - " lpips_GTvsPrediction_score= lpips_GTvsPrediction.mean().data.numpy()\n", - " lpips_score_list.append(lpips_GTvsPrediction_score)\n", - "\n", - "\n", - " lpips_GTvsSource = perceptual_diff(test_GT, test_source, 'alex', True)\n", - " lpips_GTvsSource_image = lpips_GTvsSource[0,0,...].data.cpu().numpy()\n", - " lpips_GTvsSource_score= lpips_GTvsSource.mean().data.numpy()\n", - "\n", - "\n", - " #lpips_GTvsPrediction_image_8bit = (lpips_GTvsPrediction_image* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality 
Control/\"+str(checkpoints)+\"/lpips_GTvsPrediction_\"+shortname_no_PNG+'.tif',lpips_GTvsPrediction_image)\n", - "\n", - " #lpips_GTvsSource_image_8bit = (lpips_GTvsSource_image* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsInput_\"+shortname_no_PNG+'.tif',lpips_GTvsSource_image)\n", - " else:\n", - " lpips_GTvsPrediction_score = 0\n", - " lpips_score_list.append(lpips_GTvsPrediction_score)\n", - "\n", - " lpips_GTvsSource_score = 0\n", - "\n", - "\n", - " \n", - " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource), str(lpips_GTvsPrediction_score),str(lpips_GTvsSource_score)])\n", - "\n", - " #Here we calculate the ssim average for each image in each checkpoints\n", - "\n", - " Average_SSIM_checkpoint = Average(ssim_score_list)\n", - " Average_ssim_score_list.append(Average_SSIM_checkpoint)\n", - "\n", - " Average_lpips_checkpoint = Average(lpips_score_list)\n", - " Average_lpips_score_list.append(Average_lpips_checkpoint)\n", - "\n", - "#------------------------------------------- QC for Grayscale ----------------------------------------------\n", - "\n", - " if Image_type == \"Grayscale\":\n", - " def ssim(img1, img2):\n", - " return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n", - "\n", - "\n", - " def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n", - "\n", - "\n", - " mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n", - " ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n", - " return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n", - "\n", - "\n", - " def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n", - " \n", - " if dtype is not None:\n", - " x = x.astype(dtype,copy=False)\n", - " mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n", - " ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n", - " eps = dtype(eps)\n", - "\n", - " try:\n", - " import numexpr\n", - " x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n", - " except ImportError:\n", - " x = (x - mi) / ( ma - mi + eps )\n", - "\n", - " if clip:\n", - " x = np.clip(x,0,1)\n", - "\n", - " return x\n", - "\n", - " def norm_minmse(gt, x, normalize_gt=True):\n", - " \n", - " if normalize_gt:\n", - " gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n", - " x = x.astype(np.float32, copy=False) - np.mean(x)\n", - " #x = x - np.mean(x)\n", - " gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n", - " #gt = gt - np.mean(gt)\n", - " scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n", - " return gt, scale * x\n", - "\n", - "\n", - "# Open and create the csv file that will contain all the QC metrics\n", - " with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n", - " writer = csv.writer(file)\n", - "\n", - " # Write the header in the csv file\n", - " writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\", \"Prediction v. GT lpips\", \"Input v. 
GT lpips\"]) \n", - "\n", - " # Initialize the lists\n", - " ssim_score_list = []\n", - " Pearson_correlation_coefficient_list = []\n", - " lpips_score_list = []\n", - " \n", - " # Let's loop through the provided dataset in the QC folders\n", - "\n", - " for i in os.listdir(Source_QC_folder):\n", - " if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n", - " print('Running QC on: '+i)\n", - "\n", - " shortname_no_PNG = i[:-4]\n", - " # -------------------------------- Target test data (Ground truth) --------------------------------\n", - " test_GT_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\")) \n", - " test_GT = test_GT_raw[:,:,2]\n", - "\n", - " # -------------------------------- Source test data --------------------------------\n", - " test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\")) \n", - " test_source = test_source_raw[:,:,2]\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n", - " test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n", - "\n", - " # -------------------------------- Prediction --------------------------------\n", - " test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n", - " \n", - " test_prediction = test_prediction_raw[:,:,2]\n", - "\n", - " # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n", - " test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n", - "\n", - "\n", - " # -------------------------------- Calculate the metric maps and save them --------------------------------\n", - "\n", - " # Calculate the SSIM maps\n", - " index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n", - " index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n", - "\n", - " ssim_score_list.append(index_SSIM_GTvsPrediction)\n", - "\n", - " #Save ssim_maps\n", - " \n", - " img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n", - " img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n", - " \n", - " # Calculate the Root Squared Error (RSE) maps\n", - " img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n", - " img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n", - "\n", - " # Save SE maps\n", - " img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n", - " img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n", - "\n", - "\n", - " # 
-------------------------------- Calculate the RSE metrics and save them --------------------------------\n", - "\n", - " # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n", - " NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n", - " NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n", - " \n", - " # We can also measure the peak signal to noise ratio between the images\n", - " PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n", - " PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n", - "\n", - "\n", - " \n", - " # -------------------------------- Calculate the perceptual difference metrics map and save them --------------------------------\n", - " if Do_lpips_analysis:\n", - " lpips_GTvsPrediction = perceptual_diff(test_GT_raw, test_prediction_raw, 'alex', True)\n", - " lpips_GTvsPrediction_image = lpips_GTvsPrediction[0,0,...].data.cpu().numpy()\n", - " lpips_GTvsPrediction_score= lpips_GTvsPrediction.mean().data.numpy()\n", - " lpips_score_list.append(lpips_GTvsPrediction_score)\n", - "\n", - " lpips_GTvsSource = perceptual_diff(test_GT_raw, test_source_raw, 'alex', True)\n", - " lpips_GTvsSource_image = lpips_GTvsSource[0,0,...].data.cpu().numpy()\n", - " lpips_GTvsSource_score= lpips_GTvsSource.mean().data.numpy()\n", - "\n", - "\n", - " lpips_GTvsPrediction_image_8bit = (lpips_GTvsPrediction_image* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsPrediction_\"+shortname_no_PNG+'.tif',lpips_GTvsPrediction_image_8bit)\n", - "\n", - " lpips_GTvsSource_image_8bit = (lpips_GTvsSource_image* 255).astype(\"uint8\")\n", - " io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsInput_\"+shortname_no_PNG+'.tif',lpips_GTvsSource_image_8bit)\n", - " else:\n", - " lpips_GTvsPrediction_score = 0\n", - " lpips_score_list.append(lpips_GTvsPrediction_score)\n", - "\n", - " lpips_GTvsSource_score = 0\n", - "\n", - "\n", - " writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource),str(lpips_GTvsPrediction_score),str(lpips_GTvsSource_score)])\n", - "\n", - "\n", - " #Here we calculate the ssim average for each image in each checkpoints\n", - "\n", - " Average_SSIM_checkpoint = Average(ssim_score_list)\n", - " Average_ssim_score_list.append(Average_SSIM_checkpoint)\n", - "\n", - " Average_lpips_checkpoint = Average(lpips_score_list)\n", - " Average_lpips_score_list.append(Average_lpips_checkpoint)\n", - "\n", - "\n", - "# All data is now processed saved\n", - " \n", - "\n", - "# -------------------------------- Display --------------------------------\n", - "\n", - "# Display the IoV vs Checkpoint plot\n", - "plt.figure(figsize=(20,5))\n", - "plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n", - "plt.title('Checkpoints vs. 
SSIM')\n", - "plt.ylabel('SSIM')\n", - "plt.xlabel('Checkpoints')\n", - "plt.legend()\n", - "plt.savefig(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n", - "plt.show()\n", - "\n", - "\n", - "# -------------------------------- Display --------------------------------\n", - "\n", - "if Do_lpips_analysis:\n", - " # Display the lpips vs Checkpoint plot\n", - " plt.figure(figsize=(20,5))\n", - " plt.plot(Checkpoint_list, Average_lpips_score_list, label=\"lpips\")\n", - " plt.title('Checkpoints vs. lpips')\n", - " plt.ylabel('lpips')\n", - " plt.xlabel('Checkpoints')\n", - " plt.legend()\n", - " plt.savefig(full_QC_model_path+'/Quality Control/lpipsvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n", - " plt.show()\n", - "\n", - "\n", - "\n", - "# -------------------------------- Display RGB --------------------------------\n", - "\n", - "from ipywidgets import interact\n", - "import ipywidgets as widgets\n", - "\n", - "\n", - "if Image_type == \"RGB\":\n", - " random_choice_shortname_no_PNG = shortname_no_PNG\n", - "\n", - " @interact\n", - " def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n", - "\n", - " random_choice_shortname_no_PNG = file[:-4]\n", - "\n", - " df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n", - " df2 = df1.set_index(\"image #\", drop = False)\n", - " index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n", - " index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n", - " lpips_GTvsPrediction = df2.loc[file, \"Prediction v. GT lpips\"]\n", - " lpips_GTvsSource = df2.loc[file, \"Input v. GT lpips\"]\n", - "\n", - "#Setting up colours\n", - " cmap = None\n", - "\n", - " plt.figure(figsize=(15,15))\n", - "\n", - "# Target (Ground-truth)\n", - " plt.subplot(3,3,1)\n", - " plt.axis('off')\n", - " img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"), as_gray=False, pilmode=\"RGB\")\n", - " \n", - " plt.imshow(img_GT, cmap = cmap)\n", - " plt.title('Target',fontsize=15)\n", - "\n", - "# Source\n", - " plt.subplot(3,3,2)\n", - " plt.axis('off')\n", - " img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"), as_gray=False, pilmode=\"RGB\")\n", - " plt.imshow(img_Source, cmap = cmap)\n", - " plt.title('Source',fontsize=15)\n", - "\n", - "#Prediction\n", - " plt.subplot(3,3,3)\n", - " plt.axis('off')\n", - "\n", - " img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n", - "\n", - " plt.imshow(img_Prediction, cmap = cmap)\n", - " plt.title('Prediction',fontsize=15)\n", - "\n", - "\n", - "#SSIM between GT and Source\n", - " plt.subplot(3,3,5)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_SSIM_GTvsSource = 
imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - "#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. Source',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - " plt.subplot(3,3,6)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - "\n", - " img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - "#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#lpips Error between GT and source\n", - "\n", - " if Do_lpips_analysis:\n", - " plt.subplot(3,3,8)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_lpips_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsInput_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " imlpips_GTvsSource = plt.imshow(img_lpips_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imlpips_GTvsSource,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. 
Input',fontsize=15)\n", - " plt.xlabel('lpips: '+str(round(lpips_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('Lpips maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "\n", - " #lpips Error between GT and Prediction\n", - " plt.subplot(3,3,9)\n", - " #plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_lpips_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " imlpips_GTvsPrediction = plt.imshow(img_lpips_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imlpips_GTvsPrediction,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('lpips: '+str(round(lpips_GTvsPrediction,3)),fontsize=14)\n", - "\n", - " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "# -------------------------------- Display Grayscale --------------------------------\n", - "\n", - "if Image_type == \"Grayscale\":\n", - " random_choice_shortname_no_PNG = shortname_no_PNG\n", - "\n", - " @interact\n", - " def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n", - "\n", - " random_choice_shortname_no_PNG = file[:-4]\n", - "\n", - " df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n", - " df2 = df1.set_index(\"image #\", drop = False)\n", - " index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n", - " index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n", - "\n", - " NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n", - " NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n", - " PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n", - " PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\n", - " lpips_GTvsPrediction = df2.loc[file, \"Prediction v. GT lpips\"]\n", - " lpips_GTvsSource = df2.loc[file, \"Input v. 
GT lpips\"]\n", - "\n", - " plt.figure(figsize=(20,20))\n", - " # Currently only displays the last computed set, from memory\n", - " # Target (Ground-truth)\n", - " plt.subplot(4,3,1)\n", - " plt.axis('off')\n", - " img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"))\n", - "\n", - " plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n", - " plt.title('Target',fontsize=15)\n", - "\n", - "# Source\n", - " plt.subplot(4,3,2)\n", - " plt.axis('off')\n", - " img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"))\n", - " plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n", - " plt.title('Source',fontsize=15)\n", - "\n", - "#Prediction\n", - " plt.subplot(4,3,3)\n", - " plt.axis('off')\n", - " img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n", - " plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n", - " plt.title('Prediction',fontsize=15)\n", - "\n", - "#Setting up colours\n", - " cmap = plt.cm.CMRmap\n", - "\n", - "#SSIM between GT and Source\n", - " plt.subplot(4,3,5)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - " img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n", - " imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - " \n", - " plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. Source',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#SSIM between GT and Prediction\n", - " plt.subplot(4,3,6)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False) \n", - " \n", - " img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - " img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n", - " imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n", - " \n", - " plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n", - " plt.title('Target vs. 
Prediction',fontsize=15)\n", - " plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#Root Squared Error between GT and Source\n", - " plt.subplot(4,3,8)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - " img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n", - "\n", - " imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n", - " plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Source',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n", - "#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n", - " plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#Root Squared Error between GT and Prediction\n", - " plt.subplot(4,3,9)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n", - "\n", - " imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. 
Prediction',fontsize=15)\n", - " plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n", - "\n", - "#lpips Error between GT and source\n", - "\n", - " if Do_lpips_analysis:\n", - " plt.subplot(4,3,11)\n", - "\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_lpips_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsInput_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " img_lpips_GTvsSource = img_lpips_GTvsSource / 255\n", - "\n", - " imlpips_GTvsSource = plt.imshow(img_lpips_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imlpips_GTvsSource,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Input',fontsize=15)\n", - " plt.xlabel('lpips: '+str(round(lpips_GTvsSource,3)),fontsize=14)\n", - " plt.ylabel('Lpips maps',fontsize=20, rotation=0, labelpad=75)\n", - "\n", - "#lpips Error between GT and Prediction\n", - " plt.subplot(4,3,12)\n", - "#plt.axis('off')\n", - " plt.tick_params(\n", - " axis='both', # changes apply to the x-axis and y-axis\n", - " which='both', # both major and minor ticks are affected\n", - " bottom=False, # ticks along the bottom edge are off\n", - " top=False, # ticks along the top edge are off\n", - " left=False, # ticks along the left edge are off\n", - " right=False, # ticks along the right edge are off\n", - " labelbottom=False,\n", - " labelleft=False)\n", - "\n", - " img_lpips_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n", - "\n", - " img_lpips_GTvsPrediction = img_lpips_GTvsPrediction / 255\n", - "\n", - " imlpips_GTvsPrediction = plt.imshow(img_lpips_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n", - " plt.colorbar(imlpips_GTvsPrediction,fraction=0.046,pad=0.04)\n", - " plt.title('Target vs. Prediction',fontsize=15)\n", - " plt.xlabel('lpips: '+str(round(lpips_GTvsPrediction,3)),fontsize=14)\n", - "\n", - " plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n", - "\n", - "#Make a pdf summary of the QC results\n", - "\n", - "qc_pdf_export()\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-tJeeJjLnRkP" - }, - "source": [ - "# **6. Using the trained model**\n", - "\n", - "---\n", - "\n", - "In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d8wuQGjoq6eN" - }, - "source": [ - "## **6.1. Generate prediction(s) from unseen dataset**\n", - "---\n", - "\n", - "The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
Predicted output images are saved in your **Result_folder** folder as PNG images.\n", - "\n", - "**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n", - "\n", - "**`Result_folder`:** This folder will contain the predicted output images.\n", - "\n", - "**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\".\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "y2TD5p7MZrEb", - "cellView": "form" - }, - "source": [ - "#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n", - "import glob\n", - "import os.path\n", - "\n", - "latest = \"latest\"\n", - "\n", - "Data_folder = \"\" #@param {type:\"string\"}\n", - "Result_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ###Image normalisation:\n", - "\n", - "Normalisation_prediction_source = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n", - "\n", - "# model name and path\n", - "#@markdown ###Do you want to use the current trained model?\n", - "Use_the_current_trained_model = True #@param {type:\"boolean\"}\n", - "\n", - "#@markdown ###If not, please provide the path to the model folder:\n", - "\n", - "Prediction_model_folder = \"\" #@param {type:\"string\"}\n", - "\n", - "#@markdown ###What model checkpoint would you like to use?\n", - "\n", - "checkpoint = latest#@param {type:\"raw\"}\n", - "\n", - "#@markdown ###Advanced Parameters\n", - "\n", - "patch_size = 512#@param {type:\"number\"} # in pixels\n", - "\n", - "#Here we find the loaded model name and parent path\n", - "Prediction_model_name = os.path.basename(Prediction_model_folder)\n", - "Prediction_model_path = os.path.dirname(Prediction_model_folder)\n", - "\n", - "#here we check if we use the newly trained network or not\n", - "if (Use_the_current_trained_model): \n", - " print(\"Using current trained network\")\n", - " Prediction_model_name = model_name\n", - " Prediction_model_path = model_path\n", - "\n", - "if not patch_size % 256 == 0:\n", - " patch_size = ((int(patch_size / 256)) * 256)\n", - " print (\" Your image dimensions are not divisible by 256; therefore your images have now been resized to:\",patch_size_QC)\n", - "\n", - "if patch_size < 256:\n", - " patch_size = 256\n", - "\n", - "#here we check if the model exists\n", - "full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n", - "\n", - "if os.path.exists(full_Prediction_model_path):\n", - " print(\"The \"+Prediction_model_name+\" network will be used.\")\n", - "else:\n", - " W = '\\033[0m' # white (normal)\n", - " R = '\\033[31m' # red\n", - " print(R+'!! 
WARNING: The chosen model does not exist !!'+W)\n", - " print('Please make sure you provide a valid model path and model name before proceeding further.')\n", - "\n", - "Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G.pth')))+1\n", - "\n", - "if not checkpoint == \"latest\":\n", - "\n", - " if checkpoint < 10:\n", - " checkpoint = 5\n", - "\n", - " if not checkpoint % 5 == 0:\n", - " checkpoint = ((int(checkpoint / 5)-1) * 5)\n", - " print (bcolors.WARNING + \" Your chosen checkpoints is not divisible by 5; therefore the checkpoints chosen is now:\",checkpoints)\n", - " \n", - " if checkpoint == Nb_Checkpoint*5:\n", - " checkpoint = \"latest\"\n", - "\n", - " if checkpoint > Nb_Checkpoint*5:\n", - " checkpoint = \"latest\"\n", - "\n", - "# Here we need to move the data to be analysed so that pix2pix can find them\n", - "\n", - "Saving_path_prediction= \"/content/\"+Prediction_model_name\n", - "\n", - "if os.path.exists(Saving_path_prediction):\n", - " shutil.rmtree(Saving_path_prediction)\n", - "os.makedirs(Saving_path_prediction)\n", - "\n", - "imageA_folder = Saving_path_prediction+\"/A\"\n", - "os.makedirs(imageA_folder)\n", - "\n", - "imageB_folder = Saving_path_prediction+\"/B\"\n", - "os.makedirs(imageB_folder)\n", - "\n", - "imageAB_folder = Saving_path_prediction+\"/AB\"\n", - "os.makedirs(imageAB_folder)\n", - "\n", - "testAB_Folder = Saving_path_prediction+\"/AB/test\"\n", - "os.makedirs(testAB_Folder)\n", - "\n", - "testA_Folder = Saving_path_prediction+\"/A/test\"\n", - "os.makedirs(testA_Folder)\n", - " \n", - "testB_Folder = Saving_path_prediction+\"/B/test\"\n", - "os.makedirs(testB_Folder)\n", - "\n", - "#Here we copy and normalise the data\n", - "\n", - "if Normalisation_prediction_source == \"Contrast stretching\":\n", - " \n", - " for filename in os.listdir(Data_folder):\n", - "\n", - " img = imread(os.path.join(Data_folder,filename)).astype(np.float32)\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " p2, p99 = np.percentile(img, (2, 99.9))\n", - " img = exposure.rescale_intensity(img, in_range=(p2, p99))\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - " cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n", - " cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n", - " \n", - "if Normalisation_prediction_source == \"Adaptive Equalization\":\n", - "\n", - " for filename in os.listdir(Data_folder):\n", - "\n", - " img = imread(os.path.join(Data_folder,filename))\n", - " short_name = os.path.splitext(filename)\n", - "\n", - " img = exposure.equalize_adapthist(img, clip_limit=0.03)\n", - "\n", - " img = 255 * img # Now scale by 255\n", - " img = img.astype(np.uint8)\n", - "\n", - " cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n", - " cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n", - "\n", - "if Normalisation_prediction_source == \"None\":\n", - " for files in os.listdir(Data_folder):\n", - " shutil.copyfile(Data_folder+\"/\"+files, testA_Folder+\"/\"+files)\n", - " shutil.copyfile(Data_folder+\"/\"+files, testB_Folder+\"/\"+files)\n", - " \n", - "# Here we create a merged A / A image for the prediction\n", - "os.chdir(\"/content\")\n", - "!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n", - "\n", - "# Here we count how many images are in our folder to be predicted and we had a few\n", - "Nb_files_Data_folder = len(os.listdir(Data_folder)) 
+10\n", - "\n", - "# This will find the image dimension of a randomly choosen image in Data_folder \n", - "random_choice = random.choice(os.listdir(Data_folder))\n", - "x = imageio.imread(Data_folder+\"/\"+random_choice)\n", - "\n", - "#Find image XY dimension\n", - "Image_Y = x.shape[0]\n", - "Image_X = x.shape[1]\n", - "\n", - "Image_min_dim = min(Image_Y, Image_X)\n", - "\n", - "\n", - "#-------------------------------- Perform predictions -----------------------------\n", - "\n", - "#-------------------------------- Options that can be used to perform predictions -----------------------------\n", - "\n", - "# basic parameters\n", - " #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n", - " #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n", - " #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n", - " #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n", - "\n", - "# model parameters\n", - " #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n", - " #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n", - " #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n", - " #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n", - " #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n", - " #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n", - " #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n", - " #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n", - " #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n", - " #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n", - " #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n", - " #('--no_dropout', action='store_true', help='no dropout for the generator')\n", - " \n", - "# dataset parameters\n", - " #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n", - " #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n", - " #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n", - " #('--num_threads', default=4, type=int, help='# threads for loading data')\n", - " #('--batch_size', type=int, default=1, help='input batch size')\n", - " #('--load_size', type=int, default=286, help='scale images to this size')\n", - " #('--crop_size', type=int, default=256, help='then crop to this size')\n", - " #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n", - " #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n", - " #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n", - " #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n", - " \n", - "# additional parameters\n", - " #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n", - " #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n", - " #('--verbose', action='store_true', help='if specified, print more debugging information')\n", - " #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n", - " \n", - "\n", - " #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n", - " #('--results_dir', type=str, default='./results/', help='saves results here.')\n", - " #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n", - " #('--phase', type=str, default='test', help='train, val, test, etc')\n", - "\n", - "# Dropout and Batchnorm has different behavioir during training and test.\n", - " #('--eval', action='store_true', help='use eval mode during test time.')\n", - " #('--num_test', type=int, default=50, help='how many test images to run')\n", - " # rewrite devalue values\n", - " \n", - "# To avoid cropping, the load_size should be the same as crop_size\n", - " #parser.set_defaults(load_size=parser.get_default('crop_size'))\n", - "\n", - "#------------------------------------------------------------------------\n", - "\n", - "\n", - "#---------------------------- Predictions are performed here ----------------------\n", - "\n", - "os.chdir(\"/content\")\n", - "\n", - "!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$Prediction_model_name\" --model pix2pix --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $patch_size --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n", - "\n", - "#-----------------------------------------------------------------------------------\n", - "\n", - "\n", - "Checkpoint_name = \"test_\"+str(checkpoint)\n", - "\n", - "\n", - "Prediction_results_folder = Result_folder+\"/\"+Prediction_model_name+\"/\"+Checkpoint_name+\"/images\"\n", - "\n", - "Prediction_results_images = os.listdir(Prediction_results_folder)\n", - "\n", - "for f in Prediction_results_images: \n", - " if (f.endswith(\"_real_B.png\")): \n", - " os.remove(Prediction_results_folder+\"/\"+f)\n", - "\n", - "\n", - "\n", - "\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Pdnb77E15zLE" - }, - "source": [ - "## **6.2. 
Inspect the predicted output**\n", - "---\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "CrEBdt9T53Eh", - "cellView": "form" - }, - "source": [ - "# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n", - "import os\n", - "# This will display a randomly chosen dataset input and predicted output\n", - "random_choice = random.choice(os.listdir(Data_folder))\n", - "\n", - "\n", - "random_choice_no_extension = os.path.splitext(random_choice)\n", - "\n", - "\n", - "x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real_A.png\")\n", - "\n", - "\n", - "y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake_B.png\")\n", - "\n", - "f=plt.figure(figsize=(16,8))\n", - "plt.subplot(1,2,1)\n", - "plt.imshow(x, interpolation='nearest')\n", - "plt.title('Input')\n", - "plt.axis('off');\n", - "\n", - "plt.subplot(1,2,2)\n", - "plt.imshow(y, interpolation='nearest')\n", - "plt.title('Prediction')\n", - "plt.axis('off');\n" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hvkd66PldsXB" - }, - "source": [ - "## **6.3. Download your predictions**\n", - "---\n", - "\n", - "**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UvSlTaH14s3t" - }, - "source": [ - "\n", - "#**Thank you for using pix2pix!**" - ] - } - ] -} \ No newline at end of file +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"pix2pix_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1A26cn0nxWQCv-LuP3UBfyCWlKBGIo0RU","timestamp":1610978553958},{"file_id":"1MmLTCC0nyX3Akb9V4C_OVxM3X_M8u-eX","timestamp":1610543191319},{"file_id":"1paNjUObR5Rcr4BMGADJTz0PQBBLZDPrY","timestamp":1602522500580},{"file_id":"1WZRIoSBNcRUEq4-Rq5M4mDkIaOlEHnxz","timestamp":1588762142860},{"file_id":"10weAY0es-pEfHlACCaBCKK7PmgdoJqdh","timestamp":1587728072051},{"file_id":"10Ze0rFZoooyyTL_OIVWGdFJEhWE6_cSB","timestamp":1586789421439},{"file_id":"1SsGyUbWcMaLGHFepMuKElRNYLdEBUwf6","timestamp":1583244509550}],"collapsed_sections":[],"machine_shape":"hm"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **pix2pix**\n","\n","---\n","\n","pix2pix is a deep-learning method allowing image-to-image translation from one image domain type to another image domain type. It was first published by [Isola *et al.* in 2016](https://arxiv.org/abs/1611.07004). The image transformation requires paired images for training (supervised learning) and is made possible here by using a conditional Generative Adversarial Network (GAN) architecture to use information from the input image and obtain the equivalent translated image.\n","\n"," **This particular notebook enables image-to-image translation learned from paired dataset. 
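For reference, the objective optimised during pix2pix training, as described in Isola *et al.* (2016), combines a conditional GAN loss with an L1 reconstruction loss between the generated image and its paired target (the L1 term is weighted by λ, set to 100 in the original paper):

$$
\mathcal{L}_{cGAN}(G,D) = \mathbb{E}_{x,y}\left[\log D(x,y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D\big(x, G(x,z)\big)\right)\right]
$$

$$
G^{*} = \arg\min_{G}\max_{D}\ \mathcal{L}_{cGAN}(G,D) + \lambda\,\mathbb{E}_{x,y,z}\left[\lVert y - G(x,z)\rVert_{1}\right]
$$

The paired training images are what make the L1 term possible: the network is penalised directly for deviating from the known target image of each input.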
If you are interested in performing unpaired image-to-image translation, you should consider using the CycleGAN notebook instead.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the *Zero-Cost Deep-Learning to Enhance Microscopy* project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories.\n","\n","This notebook is based on the following paper: \n","\n"," **Image-to-Image Translation with Conditional Adversarial Networks** by Isola *et al.* on arXiv in 2016 (https://arxiv.org/abs/1611.07004)\n","\n","The source code of the PyTorch implementation of pix2pix can be found here: /~https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","\n","**Please also cite this original paper when using or developing this notebook.**"]},{"cell_type":"markdown","metadata":{"id":"W7HfryEazzJE"},"source":["# **License**\n","\n","---"]},{"cell_type":"code","metadata":{"cellView":"form","id":"4TTFT14b0J6n"},"source":["#@markdown ##Double click to see the license information\n","\n","#------------------------- LICENSE FOR ZeroCostDL4Mic------------------------------------\n","#This ZeroCostDL4Mic notebook is distributed under the MIT licence\n","\n","\n","\n","#------------------------- LICENSE FOR CycleGAN ------------------------------------\n","\n","#Copyright (c) 2017, Jun-Yan Zhu and Taesung Park\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n","#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n","#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n","#DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n","#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n","#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n","#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n","#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n","#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n","#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n","\n","\n","#--------------------------- LICENSE FOR pix2pix --------------------------------\n","#BSD License\n","\n","#For pix2pix software\n","#Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu\n","#All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without\n","#modification, are permitted provided that the following conditions are met:\n","\n","#* Redistributions of source code must retain the above copyright notice, this\n","# list of conditions and the following disclaimer.\n","\n","#* Redistributions in binary form must reproduce the above copyright notice,\n","# this list of conditions and the following disclaimer in the documentation\n","# and/or other materials provided with the distribution.\n","\n","#----------------------------- LICENSE FOR DCGAN --------------------------------\n","#BSD License\n","\n","#For dcgan.torch software\n","\n","#Copyright (c) 2015, Facebook, Inc. All rights reserved.\n","\n","#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n","\n","#Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n","\n","#Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n","\n","#Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n","\n","#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["#**0. Before getting started**\n","---\n"," For pix2pix to train, **it needs to have access to a paired training dataset**. This means that the same image needs to be acquired in the two conditions and provided with indication of correspondence.\n","\n"," Therefore, the data structure is important. 
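Because training will silently go wrong if source and target images are not matched one-to-one, a quick sanity check of the folder structure can save time before starting. The sketch below is illustrative only and is not part of the notebook; the folder paths in the usage example are hypothetical placeholders for your own Training_source and Training_target folders.

```python
import os

def check_paired_dataset(source_dir, target_dir, extension=".png"):
    """Report images that lack a counterpart with the same file name in the
    other folder; pix2pix training requires exactly matched pairs."""
    source = {f for f in os.listdir(source_dir) if f.lower().endswith(extension)}
    target = {f for f in os.listdir(target_dir) if f.lower().endswith(extension)}
    for missing in sorted(source - target):
        print("No target image found for:", missing)
    for missing in sorted(target - source):
        print("No source image found for:", missing)
    if source == target:
        print(f"{len(source)} matched image pairs found.")

# Example (hypothetical paths):
# check_paired_dataset("/content/gdrive/MyDrive/Experiment_A/Training_source",
#                      "/content/gdrive/MyDrive/Experiment_A/Training_target")
```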
It is necessary that all the input data are in the same folder and that all the output data is in a separate folder. The provided training dataset is already split in two folders called Training_source and Training_target. Information on how to generate a training dataset is available in our Wiki page: /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\n","\n","**We strongly recommend that you generate extra paired images. These images can be used to assess the quality of your trained model (Quality control dataset)**. The quality control assessment can be done directly in this notebook.\n","\n"," **Additionally, the corresponding input and output files need to have the same name**.\n","\n"," Please note that you currently can **only use .PNG files!**\n","\n","\n","Here's a common data structure that can work:\n","* Experiment A\n"," - **Training dataset**\n"," - Training_source\n"," - img_1.png, img_2.png, ...\n"," - Training_target\n"," - img_1.png, img_2.png, ...\n"," - **Quality control dataset**\n"," - Training_source\n"," - img_1.png, img_2.png\n"," - Training_target\n"," - img_1.png, img_2.png\n"," - **Data to be predicted**\n"," - **Results**\n","\n","---\n","**Important note**\n","\n","- If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do that), you will need to run **sections 1 - 4**, then use **section 5** to assess the quality of your model and **section 6** to run predictions using the model that you trained.\n","\n","- If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 5** to assess the quality of your model.\n","\n","- If you only wish to **run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **sections 1 and 2** to set up the notebook, then use **section 6** to run the predictions on the desired model.\n","---"]},{"cell_type":"markdown","metadata":{"id":"AdN8B91xZO0x"},"source":["# **1. 
Install pix2pix and dependencies**\n","---\n"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = '1.13'\n","Network = 'pix2pix'\n","\n","\n","from builtins import any as b_any\n","\n","def get_requirements_path():\n"," # Store requirements file in 'contents' directory \n"," current_dir = os.getcwd()\n"," dir_count = current_dir.count('/') - 1\n"," path = '../' * (dir_count) + 'requirements.txt'\n"," return path\n","\n","def filter_files(file_list, filter_list):\n"," filtered_list = []\n"," for fname in file_list:\n"," if b_any(fname.split('==')[0] in s for s in filter_list):\n"," filtered_list.append(fname)\n"," return filtered_list\n","\n","def build_requirements_file(before, after):\n"," path = get_requirements_path()\n","\n"," # Exporting requirements.txt for local run\n"," !pip freeze > $path\n","\n"," # Get minimum requirements file\n"," df = pd.read_csv(path, delimiter = \"\\n\")\n"," mod_list = [m.split('.')[0] for m in after if not m in before]\n"," req_list_temp = df.values.tolist()\n"," req_list = [x[0] for x in req_list_temp]\n","\n"," # Replace with package name and handle cases where import name is different to module name\n"," mod_name_list = [['sklearn', 'scikit-learn'], ['skimage', 'scikit-image']]\n"," mod_replace_list = [[x[1] for x in mod_name_list] if s in [x[0] for x in mod_name_list] else s for s in mod_list] \n"," filtered_list = filter_files(req_list, mod_replace_list)\n","\n"," file=open(path,'w')\n"," for item in filtered_list:\n"," file.writelines(item + '\\n')\n","\n"," file.close()\n","\n","import sys\n","before = [str(m) for m in sys.modules]\n","\n","#@markdown ##Install pix2pix and dependencies\n","\n","#Here, we install libraries which are not already included in Colab.\n","!git clone /~https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n","import os\n","os.chdir('pytorch-CycleGAN-and-pix2pix/')\n","!pip install -r requirements.txt\n","!pip install fpdf\n","!pip install lpips\n","\n","import lpips\n","from PIL import Image\n","import imageio\n","from skimage import data\n","from skimage import exposure\n","from skimage.exposure import match_histograms\n","import os.path\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","import numpy as np\n","from matplotlib import pyplot as plt\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","from skimage.util import img_as_ubyte\n","from tqdm import tqdm \n","from fpdf import FPDF, HTMLMixin\n","from datetime import datetime\n","import subprocess\n","from pip._internal.operations.freeze import freeze\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","print('----------------------------')\n","print(\"Libraries installed\")\n","\n","\n","# Check if this is the latest version of the 
notebook\n","All_notebook_versions = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_Notebook_versions.csv\", dtype=str)\n","print('Notebook version: '+Notebook_version)\n","Latest_Notebook_version = All_notebook_versions[All_notebook_versions[\"Notebook\"] == Network]['Version'].iloc[0]\n","print('Latest notebook version: '+Latest_Notebook_version)\n","if Notebook_version == Latest_Notebook_version:\n"," print(\"This notebook is up-to-date.\")\n","else:\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","\n","\n","# average function\n","def Average(lst): \n"," return sum(lst) / len(lst) \n","\n","def perceptual_diff(im0, im1, network, spatial):\n","\n"," tensor0 = lpips.im2tensor(im0)\n"," tensor1 = lpips.im2tensor(im1)\n"," # Set up the loss function we will use\n"," loss_fn = lpips.LPIPS(net=network, spatial=spatial, verbose=False)\n","\n"," diff = loss_fn.forward(tensor0, tensor1)\n","\n"," return diff\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'pix2pix'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n"," \n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and method:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','torch']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a vanilla GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & 
Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if Use_pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a vanilla GAN loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was retrained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), numpy (v '+version_numbers[1]+'), torch (v '+version_numbers[2]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n"," if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if flip_left_right != 0 or flip_top_bottom != 0:\n"," aug_text = aug_text+'\\n- flipping'\n"," if random_zoom_magnification != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if random_distortion != 0:\n"," aug_text = aug_text+'\\n- random distortion'\n"," if image_shear != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," if skew_image != 0:\n"," aug_text = aug_text+'\\n- image skewing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
patch_size{1}
batch_size{2}
initial_learning_rate{3}
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),batch_size,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(30, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_pix2pix.png').shape\n"," pdf.image('/content/TrainingDataExample_pix2pix.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- pix2pix: Isola, Phillip, et al. \"Image-to-image translation with conditional adversarial networks.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," if augmentation:\n"," ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. 
\"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","def qc_pdf_export():\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'pix2pix'\n","\n","\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Quality Control report for '+Network+' model ('+QC_model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n","\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(2)\n"," pdf.cell(190, 5, txt = 'Development of Training Losses', ln=1, align='L')\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png').shape\n"," pdf.image(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(2)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(3)\n"," pdf.cell(80, 5, txt = 'Example Quality Control Visualisation', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread(full_QC_model_path+'/Quality Control/QC_example_data.png').shape\n"," if Image_type == 'RGB':\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/5), h = round(exp_size[0]/5))\n"," if Image_type == 'Grayscale':\n"," pdf.image(full_QC_model_path+'/Quality Control/QC_example_data.png', x = 16, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Quality Control Metrics', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n","\n"," pdf.ln(1)\n"," for checkpoint in os.listdir(full_QC_model_path+'/Quality Control'):\n"," if os.path.isdir(os.path.join(full_QC_model_path,'Quality Control',checkpoint)) and checkpoint != 'Prediction':\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(70, 5, txt = 'Metrics for checkpoint: '+ str(checkpoint), align='L', ln=1)\n"," html = \"\"\"\n"," \n"," \n"," \"\"\"\n"," with open(full_QC_model_path+'/Quality Control/'+str(checkpoint)+'/QC_metrics_'+QC_model_name+str(checkpoint)+'.csv', 'r') as csvfile:\n"," metrics = csv.reader(csvfile)\n"," header = next(metrics)\n"," image = header[0]\n"," mSSIM_PvsGT = header[1]\n"," mSSIM_SvsGT = header[2]\n"," header = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,mSSIM_PvsGT,mSSIM_SvsGT)\n"," html = html+header\n"," for row in metrics:\n"," image = row[0]\n"," mSSIM_PvsGT = row[1]\n"," mSSIM_SvsGT = row[2]\n"," cells = \"\"\"\n"," \n"," \n"," \n"," \n"," \"\"\".format(image,str(round(float(mSSIM_PvsGT),3)),str(round(float(mSSIM_SvsGT),3)))\n"," html = html+cells\n"," html = html+\"\"\"
{0}{1}{2}
{0}{1}{2}
\"\"\"\n"," pdf.write_html(html)\n"," pdf.ln(2)\n"," else:\n"," continue\n","\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"Democratising deep learning for microscopy with ZeroCostDL4Mic.\" Nature Communications (2021).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- pix2pix: Isola, Phillip, et al. \"Image-to-image translation with conditional adversarial networks.\" Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n","\n"," pdf.ln(3)\n"," reminder = 'To find the parameters and other information about how this model was trained, go to the training_report.pdf of this model which should be in the folder of the same name.'\n","\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(full_QC_model_path+'/Quality Control/'+QC_model_name+'_QC_report.pdf')\n","\n","# Build requirements file for local run\n","after = [str(m) for m in sys.modules]\n","build_requirements_file(before, after)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["\n","## **2.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelerator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"id":"zCvebubeSaGY","cellView":"form"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sNIVx8_CLolt"},"source":["## **2.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"id":"01Djr8v-5pPk","cellView":"form"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". 
\n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"FQ_QxtSWQ7CL"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"AuESFimvMv43"},"source":[" **Paths for training, predictions and results**\n","\n","**`Training_source:`, `Training_target`:** These are the paths to your folders containing the Training_source and Training_target training data respecively. To find the paths of the folders containing the respective datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`model_name`:** Use only my_model -style, not my-model (Use \"_\" not \"-\"). Do not use spaces in the name. Avoid using the name of an existing model (saved in the same folder) as it will be overwritten.\n","\n","**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).\n","\n","**Training parameters**\n","\n","**`number_of_epochs`:**Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10) epochs, but a full training should run for 200 epochs or more. Evaluate the performance after training (see 5). **Default value: 200**\n","\n","**Advanced Parameters - experienced users only**\n","\n","**`patch_size`:** pix2pix divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 512**\n","\n","**When choosing the patch_size, the value should be i) large enough that it will enclose many instances, ii) small enough that the resulting patches fit into the RAM.** \n","\n","**`batch_size:`** This parameter defines the number of patches seen in each training step. Reducing or increasing the **batch size** may slow or speed up your training, respectively, and can influence network performance. **Default value: 1**\n","\n","**`initial_learning_rate`:** Input the initial value to be used as learning rate. 
**Default value: 0.0002**"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["import cv2\n","\n","#@markdown ###Path to training images:\n","\n","Training_source = \"\" #@param {type:\"string\"}\n","#InputFile = Training_source+\"/*.png\"\n","\n","Training_target = \"\" #@param {type:\"string\"}\n","#OutputFile = Training_target+\"/*.png\"\n","\n","#@markdown ###Image normalisation:\n","\n","Normalisation_training_source = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n","Normalisation_training_target = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n","\n","\n","#Define where the patch file will be saved\n","base = \"/content\"\n","\n","# model name and path\n","#@markdown ###Name of the model and path to model folder:\n","model_name = \"\" #@param {type:\"string\"}\n","model_path = \"\" #@param {type:\"string\"}\n","\n","# other parameters for training.\n","#@markdown ###Training Parameters\n","#@markdown Number of epochs:\n","number_of_epochs = 200#@param {type:\"number\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please input:\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","batch_size = 1#@param {type:\"number\"}\n","initial_learning_rate = 0.0002 #@param {type:\"number\"}\n","\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\")\n"," batch_size = 1\n"," patch_size = 512\n"," initial_learning_rate = 0.0002\n","\n","#here we check that no model with the same name already exist, if so delete\n","if os.path.exists(model_path+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\")\n"," \n","#To use pix2pix we need to organise the data in a way the network can understand\n","\n","Saving_path= \"/content/\"+model_name\n","\n","if os.path.exists(Saving_path):\n"," shutil.rmtree(Saving_path)\n","os.makedirs(Saving_path)\n","\n","imageA_folder = Saving_path+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","TrainA_Folder = Saving_path+\"/A/train\"\n","os.makedirs(TrainA_Folder)\n"," \n","TrainB_Folder = Saving_path+\"/B/train\"\n","os.makedirs(TrainB_Folder)\n","\n","# Here we disable pre-trained model by default (in case the cell is not ran)\n","Use_pretrained_model = False\n","\n","# Here we disable data augmentation by default (in case the cell is not ran)\n","\n","Use_Data_augmentation = False\n","\n","# Here we normalise the image is enabled\n","\n","if Normalisation_training_source == \"Contrast stretching\":\n","\n"," Training_source_norm = Saving_path+\"/Training_source_norm\"\n"," os.makedirs(Training_source_norm)\n"," \n"," for filename in os.listdir(Training_source):\n","\n"," img = imread(os.path.join(Training_source,filename)).astype(np.float32)\n"," short_name = os.path.splitext(filename)\n","\n"," p2, p99 = np.percentile(img, (2, 99.9))\n"," img = exposure.rescale_intensity(img, in_range=(p2, p99))\n","\n"," img = 255 * img # Now scale by 255\n"," img = img.astype(np.uint8)\n"," cv2.imwrite(Training_source_norm+\"/\"+short_name[0]+\".png\", img)\n"," 
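The two optional normalisation modes offered in this cell can be summarised as follows. This is a condensed sketch using the same scikit-image functions the notebook imports (`skimage.exposure`), assuming the input is a NumPy array loaded with `imread` or similar; it is not a replacement for the cell itself.

```python
import numpy as np
from skimage import exposure

def normalise_image(img, mode="None"):
    """Condensed sketch of the optional normalisation modes offered in this
    notebook, returning an 8-bit image ready to be written out as PNG."""
    if mode == "Contrast stretching":
        img = img.astype(np.float32)
        p2, p99 = np.percentile(img, (2, 99.9))
        img = exposure.rescale_intensity(img, in_range=(p2, p99))  # rescaled to [0, 1]
    elif mode == "Adaptive Equalization":
        img = exposure.equalize_adapthist(img, clip_limit=0.03)    # output in [0, 1]
    else:
        return img  # "None": leave the image untouched
    return (255 * img).astype(np.uint8)
```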
\n"," Training_source = Training_source_norm\n","\n","\n","if Normalisation_training_target == \"Contrast stretching\":\n","\n"," Training_target_norm = Saving_path+\"/Training_target_norm\"\n"," os.makedirs(Training_target_norm)\n","\n"," for filename in os.listdir(Training_target):\n","\n"," img = imread(os.path.join(Training_target,filename)).astype(np.float32)\n"," short_name = os.path.splitext(filename)\n","\n"," p2, p99 = np.percentile(img, (2, 99.9))\n"," img = exposure.rescale_intensity(img, in_range=(p2, p99))\n","\n"," img = 255 * img # Now scale by 255\n"," img = img.astype(np.uint8)\n"," cv2.imwrite(Training_target_norm+\"/\"+short_name[0]+\".png\", img)\n","\n"," Training_target = Training_target_norm\n","\n","\n","if Normalisation_training_source == \"Adaptive Equalization\":\n"," Training_source_norm = Saving_path+\"/Training_source_norm\"\n"," os.makedirs(Training_source_norm)\n","\n"," for filename in os.listdir(Training_source):\n","\n"," img = imread(os.path.join(Training_source,filename))\n"," short_name = os.path.splitext(filename)\n","\n"," img = exposure.equalize_adapthist(img, clip_limit=0.03)\n","\n"," img = 255 * img # Now scale by 255\n"," img = img.astype(np.uint8)\n"," cv2.imwrite(Training_source_norm+\"/\"+short_name[0]+\".png\", img)\n","\n","\n"," Training_source = Training_source_norm\n","\n","\n","if Normalisation_training_target == \"Adaptive Equalization\":\n","\n"," Training_target_norm = Saving_path+\"/Training_target_norm\"\n"," os.makedirs(Training_target_norm)\n","\n"," for filename in os.listdir(Training_target):\n","\n"," img = imread(os.path.join(Training_target,filename))\n"," short_name = os.path.splitext(filename)\n","\n"," p2, p99 = np.percentile(img, (2, 99.8))\n"," img = exposure.rescale_intensity(img, in_range=(p2, p99))\n","\n"," img = 255 * img # Now scale by 255\n"," img = img.astype(np.uint8)\n"," cv2.imwrite(Training_target_norm+\"/\"+short_name[0]+\".png\", img)\n","\n"," Training_target = Training_target_norm\n","\n","# This will display a randomly chosen dataset input and output\n","random_choice = random.choice(os.listdir(Training_source))\n","x = io.imread(Training_source+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","#Hyperparameters failsafes\n","if patch_size > min(Image_Y, Image_X):\n"," patch_size = min(Image_Y, Image_X)\n"," print (bcolors.WARNING + \" Your chosen patch_size is bigger than the xy dimension of your image; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is divisible by 4\n","if not patch_size % 4 == 0:\n"," patch_size = ((int(patch_size / 4)-1) * 4)\n"," print (bcolors.WARNING + \" Your chosen patch_size is not divisible by 4; therefore the patch_size chosen is now:\",patch_size)\n","\n","# Here we check that patch_size is at least bigger than 256\n","if patch_size < 256:\n"," patch_size = 256\n"," print (bcolors.WARNING + \" Your chosen patch_size is too small; therefore the patch_size chosen is now:\",patch_size)\n","\n","y = io.imread(Training_target+\"/\"+random_choice)\n","\n","n_channel_x = 1 if x.ndim == 2 else x.shape[-1]\n","n_channel_y = 1 if y.ndim == 2 else y.shape[-1]\n","\n","if n_channel_x == 1:\n"," cmap_x = 'gray'\n","else:\n"," cmap_x = None\n","\n","if n_channel_y == 1:\n"," cmap_y = 'gray'\n","else:\n"," cmap_y = None\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, cmap=cmap_x, 
interpolation='nearest')\n","plt.title('Training source')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, cmap=cmap_y, interpolation='nearest')\n","plt.title('Training target')\n","plt.axis('off');\n","\n","plt.savefig('/content/TrainingDataExample_pix2pix.png',bbox_inches='tight',pad_inches=0)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xyQZKby8yFME"},"source":["## **3.2. Data augmentation**\n","---\n",""]},{"cell_type":"markdown","metadata":{"id":"w_jCy7xOx2g3"},"source":["Data augmentation can improve training progress by amplifying differences in the dataset. This can be useful if the available dataset is small since, in this case, it is possible that a network could quickly learn every example in the dataset (overfitting), without augmentation. Augmentation is not necessary for training and if your training dataset is large you should disable it.\n","\n","Data augmentation is performed here by [Augmentor.](/~https://github.com/mdbloice/Augmentor)\n","\n","[Augmentor](/~https://github.com/mdbloice/Augmentor) was described in the following article:\n","\n","Marcus D Bloice, Peter M Roth, Andreas Holzinger, Biomedical image augmentation using Augmentor, Bioinformatics, https://doi.org/10.1093/bioinformatics/btz259\n","\n","**Please also cite this original paper when publishing results obtained using this notebook with augmentation enabled.** "]},{"cell_type":"code","metadata":{"id":"DMqWq5-AxnFU","cellView":"form"},"source":["#Data augmentation\n","\n","Use_Data_augmentation = False #@param {type:\"boolean\"}\n","\n","if Use_Data_augmentation:\n"," !pip install Augmentor\n"," import Augmentor\n","\n","\n","#@markdown ####Choose a factor by which you want to multiply your original dataset\n","\n","Multiply_dataset_by = 2 #@param {type:\"slider\", min:1, max:30, step:1}\n","\n","Save_augmented_images = False #@param {type:\"boolean\"}\n","\n","Saving_path = \"\" #@param {type:\"string\"}\n","\n","\n","Use_Default_Augmentation_Parameters = True #@param {type:\"boolean\"}\n","#@markdown ###If not, please choose the probability of the following image manipulations to be used to augment your dataset (1 = always used; 0 = disabled ):\n","\n","#@markdown ####Mirror and rotate images\n","rotate_90_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","rotate_270_degrees = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_left_right = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","flip_top_bottom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image Zoom\n","\n","random_zoom = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","random_zoom_magnification = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","#@markdown ####Random image distortion\n","\n","random_distortion = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","#@markdown ####Image shearing and skewing \n","\n","image_shear = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","max_image_shear = 10 #@param {type:\"slider\", min:1, max:25, step:1}\n","\n","skew_image = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","skew_image_magnitude = 0 #@param {type:\"slider\", min:0, max:1, step:0.1}\n","\n","\n","if Use_Default_Augmentation_Parameters:\n"," rotate_90_degrees = 0.5\n"," rotate_270_degrees = 0.5\n"," flip_left_right = 0.5\n"," flip_top_bottom = 0.5\n","\n"," if not Multiply_dataset_by >5:\n"," random_zoom = 0\n"," random_zoom_magnification = 
0.9\n"," random_distortion = 0\n"," image_shear = 0\n"," max_image_shear = 10\n"," skew_image = 0\n"," skew_image_magnitude = 0\n","\n"," if Multiply_dataset_by >5:\n"," random_zoom = 0.1\n"," random_zoom_magnification = 0.9\n"," random_distortion = 0.5\n"," image_shear = 0.2\n"," max_image_shear = 5\n"," skew_image = 0.2\n"," skew_image_magnitude = 0.4\n","\n"," if Multiply_dataset_by >25:\n"," random_zoom = 0.5\n"," random_zoom_magnification = 0.8\n"," random_distortion = 0.5\n"," image_shear = 0.5\n"," max_image_shear = 20\n"," skew_image = 0.5\n"," skew_image_magnitude = 0.6\n","\n","\n","list_files = os.listdir(Training_source)\n","Nb_files = len(list_files)\n","\n","Nb_augmented_files = (Nb_files * Multiply_dataset_by)\n","\n","\n","if Use_Data_augmentation:\n"," print(\"Data augmentation enabled\")\n","# Here we set the path for the various folder were the augmented images will be loaded\n","\n","# All images are first saved into the augmented folder\n"," #Augmented_folder = \"/content/Augmented_Folder\"\n"," \n"," if not Save_augmented_images:\n"," Saving_path= \"/content\"\n","\n"," Augmented_folder = Saving_path+\"/Augmented_Folder\"\n"," if os.path.exists(Augmented_folder):\n"," shutil.rmtree(Augmented_folder)\n"," os.makedirs(Augmented_folder)\n","\n"," #Training_source_augmented = \"/content/Training_source_augmented\"\n"," Training_source_augmented = Saving_path+\"/Training_source_augmented\"\n","\n"," if os.path.exists(Training_source_augmented):\n"," shutil.rmtree(Training_source_augmented)\n"," os.makedirs(Training_source_augmented)\n","\n"," #Training_target_augmented = \"/content/Training_target_augmented\"\n"," Training_target_augmented = Saving_path+\"/Training_target_augmented\"\n","\n"," if os.path.exists(Training_target_augmented):\n"," shutil.rmtree(Training_target_augmented)\n"," os.makedirs(Training_target_augmented)\n","\n","\n","# Here we generate the augmented images\n","#Load the images\n"," p = Augmentor.Pipeline(Training_source, Augmented_folder)\n","\n","#Define the matching images\n"," p.ground_truth(Training_target)\n","#Define the augmentation possibilities\n"," if not rotate_90_degrees == 0:\n"," p.rotate90(probability=rotate_90_degrees)\n"," \n"," if not rotate_270_degrees == 0:\n"," p.rotate270(probability=rotate_270_degrees)\n","\n"," if not flip_left_right == 0:\n"," p.flip_left_right(probability=flip_left_right)\n","\n"," if not flip_top_bottom == 0:\n"," p.flip_top_bottom(probability=flip_top_bottom)\n","\n"," if not random_zoom == 0:\n"," p.zoom_random(probability=random_zoom, percentage_area=random_zoom_magnification)\n"," \n"," if not random_distortion == 0:\n"," p.random_distortion(probability=random_distortion, grid_width=4, grid_height=4, magnitude=8)\n","\n"," if not image_shear == 0:\n"," p.shear(probability=image_shear,max_shear_left=20,max_shear_right=20)\n"," \n"," if not skew_image == 0:\n"," p.skew(probability=skew_image,magnitude=skew_image_magnitude)\n","\n"," p.sample(int(Nb_augmented_files))\n","\n"," print(int(Nb_augmented_files),\"matching images generated\")\n","\n","# Here we sort through the images and move them back to augmented trainning source and targets folders\n","\n"," augmented_files = os.listdir(Augmented_folder)\n","\n"," for f in augmented_files:\n","\n"," if (f.startswith(\"_groundtruth_(1)_\")):\n"," shortname_noprefix = f[17:]\n"," shutil.copyfile(Augmented_folder+\"/\"+f, Training_target_augmented+\"/\"+shortname_noprefix) \n"," if not (f.startswith(\"_groundtruth_(1)_\")):\n"," 
shutil.copyfile(Augmented_folder+\"/\"+f, Training_source_augmented+\"/\"+f)\n"," \n","\n"," for filename in os.listdir(Training_source_augmented):\n"," os.chdir(Training_source_augmented)\n"," os.rename(filename, filename.replace('_original', ''))\n"," \n"," #Here we clean up the extra files\n"," shutil.rmtree(Augmented_folder)\n","\n","if not Use_Data_augmentation:\n"," print(bcolors.WARNING+\"Data augmentation disabled\") \n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"3L9zSGtORKYI"},"source":["\n","## **3.3. Using weights from a pre-trained model as initial weights**\n","---\n"," Here, you can set the path to a pre-trained model from which the weights can be extracted and used as a starting point for this training session. **This pre-trained model needs to be a pix2pix model**. \n","\n"," This option allows you to perform training over multiple Colab runtimes or to do transfer learning using models trained outside of ZeroCostDL4Mic. **You do not need to run this section if you want to train a network from scratch**.\n"]},{"cell_type":"code","metadata":{"id":"9vC2n-HeLdiJ","cellView":"form"},"source":["# @markdown ##Loading weights from a pre-trained network\n","\n","\n","Use_pretrained_model = False #@param {type:\"boolean\"}\n","\n","\n","#@markdown ###If yes, please provide the path to the model folder:\n","pretrained_model_path = \"\" #@param {type:\"string\"}\n","\n","# --------------------- Check if we load a previously trained model ------------------------\n","if Use_pretrained_model:\n","\n"," h5_file_path = os.path.join(pretrained_model_path, \"latest_net_G.pth\")\n"," \n","\n","# --------------------- Check the model exists ------------------------\n","\n"," if not os.path.exists(h5_file_path):\n"," print(bcolors.WARNING+'WARNING: Pretrained model does not exist')\n"," Use_pretrained_model = False\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"," if os.path.exists(h5_file_path):\n"," print(\"Pretrained model \"+os.path.basename(pretrained_model_path)+\" was found and will be loaded prior to training.\")\n"," \n","else:\n"," print(bcolors.WARNING+'No pretrained network will be used.')\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MCGklf1vZf2M"},"source":["#**4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"1KYOuygETJkT"},"source":["## **4.1. Prepare the training data for training**\n","---\n","Here, we use the information from Section 3 to prepare the training data into a suitable format for training. 
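\n","\n","For reference, pix2pix is trained here on \"aligned\" image pairs: the `combine_A_and_B.py` script from the pytorch-CycleGAN-and-pix2pix repository concatenates each source image (A) with its matching target image (B) side-by-side into a single A|B image. A minimal sketch of this pairing step is shown below (illustrative only, with hypothetical file names; the actual script used in the next cell also handles resizing and the folder structure):\n","\n","```python\n","import numpy as np\n","import imageio\n","\n","# Hypothetical matching source (A) and target (B) images of identical size\n","imgA = imageio.imread('trainA/example.png')\n","imgB = imageio.imread('trainB/example.png')\n","\n","# pix2pix aligned datasets store each pair as one side-by-side A|B image\n","imgAB = np.concatenate([imgA, imgB], axis=1)\n","imageio.imwrite('trainAB/example.png', imgAB)\n","```\n","\n","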
**Your data will be copied in the google Colab \"content\" folder which may take some time depending on the size of your dataset.**\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"lIUAOJ_LMv5E","cellView":"form"},"source":["#@markdown ##Prepare the data for training\n","\n","\n","# --------------------- Here we load the augmented data or the raw data ------------------------\n","\n","if Use_Data_augmentation:\n"," Training_source_dir = Training_source_augmented\n"," Training_target_dir = Training_target_augmented\n","\n","if not Use_Data_augmentation:\n"," Training_source_dir = Training_source\n"," Training_target_dir = Training_target\n","# --------------------- ------------------------------------------------\n","\n","print(\"Data preparation in progress\")\n","\n","if os.path.exists(model_path+'/'+model_name):\n"," shutil.rmtree(model_path+'/'+model_name)\n","os.makedirs(model_path+'/'+model_name)\n","\n","#--------------- Here we move the files to trainA and train B ---------\n","\n","\n","print('Copying training source data...')\n","for f in tqdm(os.listdir(Training_source_dir)):\n"," shutil.copyfile(Training_source_dir+\"/\"+f, TrainA_Folder+\"/\"+f)\n","\n","print('Copying training target data...')\n","for f in tqdm(os.listdir(Training_target_dir)):\n"," shutil.copyfile(Training_target_dir+\"/\"+f, TrainB_Folder+\"/\"+f)\n","\n","#---------------------------------------------------------------------\n","\n","#--------------- Here we combined A and B images---------\n","os.chdir(\"/content\")\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","\n","\n","# pix2pix uses EPOCH without lr decay and EPOCH with lr decay, here we automatically choose half and half\n","\n","number_of_epochs_lr_stable = int(number_of_epochs/2)\n","number_of_epochs_lr_decay = int(number_of_epochs/2)\n","\n","if Use_pretrained_model :\n"," for f in os.listdir(pretrained_model_path):\n"," if (f.startswith(\"latest_net_\")): \n"," shutil.copyfile(pretrained_model_path+\"/\"+f, model_path+'/'+model_name+\"/\"+f)\n","\n","#Export of pdf summary of training parameters\n","pdf_export(augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n","print('------------------------')\n","print(\"Data ready for training\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Training**\n","---\n","When playing the cell below you should see updates after each epoch (round). Network training can take some time.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches or continue the training in a second Colab session. **Pix2pix will save model checkpoints every 5 epochs.**\n","\n","Once training is complete, the trained model is automatically saved on your Google Drive, in the **model_path** folder that was selected in Section 3. 
It is however wise to download the folder as all data can be erased at the next training if using the same folder."]},{"cell_type":"code","metadata":{"scrolled":true,"id":"iwNmp1PUzRDQ","cellView":"form"},"source":["#@markdown ##Start training\n","\n","start = time.time()\n","\n","os.chdir(\"/content\")\n","\n","#--------------------------------- Command line inputs to change pix2pix paramaters------------\n","\n"," # basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n"," \n"," # model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n"," # dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n"," # additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n"," # visdom and HTML visualization parameters\n"," #('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n"," #('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n"," #('--display_id', type=int, default=1, help='window id of the web display')\n"," #('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n"," #('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n"," #('--display_port', type=int, default=8097, help='visdom port of the web display')\n"," #('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n"," #('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n"," #('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"," \n"," # network saving and loading parameters\n"," #('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n"," #('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n"," #('--save_by_iter', action='store_true', help='whether saves model by iteration')\n"," #('--continue_train', action='store_true', help='continue training: load the latest model')\n"," #('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...')\n"," #('--phase', type=str, default='train', help='train, val, test, etc')\n"," \n"," # training parameters\n"," #('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')\n"," #('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')\n"," #('--beta1', type=float, default=0.5, help='momentum term of adam')\n"," #('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n"," #('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n"," #('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n"," #('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n"," #('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations'\n","\n","#---------------------------------------------------------\n","\n","#----- Start the training ------------------------------------\n","if not Use_pretrained_model:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5\n","\n","if Use_pretrained_model:\n"," !python pytorch-CycleGAN-and-pix2pix/train.py --dataroot \"$imageAB_folder\" --name $model_name --model pix2pix --batch_size $batch_size --preprocess scale_width_and_crop --load_size $Image_min_dim --crop_size $patch_size --checkpoints_dir \"$model_path\" --no_html --n_epochs $number_of_epochs_lr_stable --n_epochs_decay $number_of_epochs_lr_decay --lr $initial_learning_rate --display_id 0 --save_epoch_freq 5 --continue_train\n","\n","\n","#---------------------------------------------------------\n","\n","print(\"Training, done.\")\n","\n","# Displaying the time elapsed for training\n","dt = time.time() - start\n","mins, sec = divmod(dt, 60) \n","hour, mins = divmod(mins, 60) \n","print(\"Time elapsed:\",hour, \"hour(s)\",mins,\"min(s)\",round(sec),\"sec(s)\")\n","\n","# Export pdf summary after training to update document\n","pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_0Hynw3-xHp1"},"source":["# **5. Evaluate your model**\n","---\n","\n","This section allows the user to perform important quality checks on the validity and generalisability of the trained model. \n","\n","**We highly recommend to perform quality control on all newly trained models.**\n","\n"]},{"cell_type":"markdown","metadata":{"id":"HQqBkYzT4hQS"},"source":["## **5.1. Choose the model you want to assess**"]},{"cell_type":"code","metadata":{"id":"eAJzMwPA6tlH","cellView":"form"},"source":["# model name and path\n","#@markdown ###Do you want to assess the model you just trained ?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","QC_model_folder = \"\" #@param {type:\"string\"}\n","\n","#Here we define the loaded model name and path\n","QC_model_name = os.path.basename(QC_model_folder)\n","QC_model_path = os.path.dirname(QC_model_folder)\n","\n","if (Use_the_current_trained_model): \n"," QC_model_name = model_name\n"," QC_model_path = model_path\n","\n","full_QC_model_path = QC_model_path+'/'+QC_model_name+'/'\n","if os.path.exists(full_QC_model_path):\n"," print(\"The \"+QC_model_name+\" network will be evaluated\")\n","else:\n"," W = '\\033[0m' # white (normal)\n"," R = '\\033[31m' # red\n"," print(R+'!! WARNING: The chosen model does not exist !!'+W)\n"," print('Please make sure you provide a valid model path and model name before proceeding further.')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"kittWWbs4pc8"},"source":["## **5.2. 
Identify the best checkpoint to use to make predictions**"]},{"cell_type":"markdown","metadata":{"id":"SeGNGf4A4ukf"},"source":[" Pix2pix saves model checkpoints every five epochs. Due to the stochastic nature of GAN networks, the last checkpoint is not always the best one to use. As a consequence, it can be challenging to choose the most suitable checkpoint to use to make predictions.\n","\n","This section allows you to perform predictions using all the saved checkpoints and to estimate the quality of these predictions by comparing them to the provided ground truth images. The metrics used include:\n","\n","**1. The SSIM (structural similarity) map** \n","\n","The SSIM metric is used to evaluate whether two images contain the same structures. It is a normalised metric and an SSIM of 1 indicates a perfect similarity between two images. Therefore for SSIM, the closer to 1, the better. The SSIM maps are constructed by calculating the SSIM metric at each pixel, considering the structural similarity in the neighbourhood of that pixel (currently defined as a window of 11 pixels with a Gaussian weighting of 1.5 pixel standard deviation, see our Wiki for more info). \n","\n","**mSSIM** is the mean SSIM value calculated across the entire image.\n","\n","**The output below shows the SSIM maps with the mSSIM values.**\n","\n","**2. The RSE (Root Squared Error) map** \n","\n","This is a display of the root of the squared difference between the normalised prediction and the target, or between the source and the target. In this case, a smaller RSE is better. A perfect agreement between target and prediction will lead to an RSE map showing zeros everywhere (dark).\n","\n","\n","**NRMSE (normalised root mean squared error)** gives the average difference between all pixels of the two compared images. Good agreement yields low NRMSE scores.\n","\n","**PSNR (Peak signal-to-noise ratio)** is a metric that gives the difference between the ground truth and the prediction (or source input) in decibels, using the peak pixel values of the prediction and the MSE between the images. 
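\n","\n","As an illustration, these metrics can be computed with scikit-image; the sketch below uses two small random arrays rather than real data and omits the normalisation applied in the cell below, so it is only meant to show what is being measured:\n","\n","```python\n","import numpy as np\n","from skimage.metrics import structural_similarity, peak_signal_noise_ratio\n","\n","# Hypothetical ground-truth and prediction images, scaled to [0, 1]\n","gt = np.random.rand(128, 128).astype(np.float32)\n","pred = np.clip(gt + 0.05 * np.random.randn(128, 128).astype(np.float32), 0, 1)\n","\n","mssim, ssim_map = structural_similarity(gt, pred, data_range=1.0, full=True)\n","rse_map = np.sqrt(np.square(gt - pred))   # per-pixel root squared error map\n","nrmse = np.sqrt(np.mean(rse_map))         # as computed in the cell below\n","psnr = peak_signal_noise_ratio(gt, pred, data_range=1.0)\n","print(mssim, nrmse, psnr)\n","```\n","\n","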
The higher the score the better the agreement.\n","\n","**The output below shows the RSE maps with the NRMSE and PSNR values.**\n","\n"]},{"cell_type":"code","metadata":{"id":"VfF_oMpI4-Xl","cellView":"form"},"source":["#@markdown ##Choose the folders that contain your Quality Control dataset\n","\n","import glob\n","import os.path\n","from scipy import stats\n","\n","#@markdown ###Path to images:\n","\n","Source_QC_folder = \"\" #@param{type:\"string\"}\n","Target_QC_folder = \"\" #@param{type:\"string\"}\n","\n","#@markdown ###Type of images:\n","\n","Image_type = \"RGB\" #@param [\"Grayscale\", \"RGB\"]\n","\n","#@markdown ###Image normalisation:\n","\n","Normalisation_QC_source = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n","Normalisation_QC_target = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n","\n","\n","#@markdown ###Advanced Parameters\n","\n","patch_size_QC = 1024#@param {type:\"number\"} # in pixels\n","Do_lpips_analysis = True #@param {type:\"boolean\"}\n","\n","\n","\n","# Create a quality control folder\n","\n","if os.path.exists(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\"):\n"," shutil.rmtree(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","os.makedirs(QC_model_path+\"/\"+QC_model_name+\"/Quality Control\")\n","\n","\n","# Create a quality control/Prediction Folder\n","\n","QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/Prediction\"\n","\n","if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n","os.makedirs(QC_prediction_results)\n","\n","# Here we count how many images are in our folder to be predicted and we had a few\n","Nb_files_Data_folder = len(os.listdir(Source_QC_folder)) +10\n","\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","# Here we need to move the data to be analysed so that pix2pix can find them\n","\n","Saving_path_QC= \"/content/\"+QC_model_name+\"_images\"\n","\n","if os.path.exists(Saving_path_QC):\n"," shutil.rmtree(Saving_path_QC)\n","os.makedirs(Saving_path_QC)\n","\n","Saving_path_QC_folder = Saving_path_QC+\"/QC\"\n","\n","if os.path.exists(Saving_path_QC_folder):\n"," shutil.rmtree(Saving_path_QC_folder)\n","os.makedirs(Saving_path_QC_folder)\n","\n","imageA_folder = Saving_path_QC_folder+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path_QC_folder+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path_QC_folder+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","testAB_folder = Saving_path_QC_folder+\"/AB/test\"\n","os.makedirs(testAB_folder)\n","\n","testA_Folder = Saving_path_QC_folder+\"/A/test\"\n","os.makedirs(testA_Folder)\n"," \n","testB_Folder = Saving_path_QC_folder+\"/B/test\"\n","os.makedirs(testB_Folder)\n","\n","QC_checkpoint_folders = \"/content/\"+QC_model_name\n","\n","if os.path.exists(QC_checkpoint_folders):\n"," shutil.rmtree(QC_checkpoint_folders)\n","os.makedirs(QC_checkpoint_folders)\n","\n","#Here we copy and normalise the data\n","\n","if Normalisation_QC_source == \"Contrast stretching\":\n"," \n"," for filename in os.listdir(Source_QC_folder):\n","\n"," img = 
imread(os.path.join(Source_QC_folder,filename)).astype(np.float32)\n"," short_name = os.path.splitext(filename)\n","\n"," p2, p99 = np.percentile(img, (2, 99.9))\n"," img = exposure.rescale_intensity(img, in_range=(p2, p99))\n","\n"," img = 255 * img # Now scale by 255\n"," img = img.astype(np.uint8)\n"," cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n"," \n","if Normalisation_QC_target == \"Contrast stretching\":\n","\n"," for filename in os.listdir(Target_QC_folder):\n","\n"," img = imread(os.path.join(Target_QC_folder,filename)).astype(np.float32)\n"," short_name = os.path.splitext(filename)\n","\n"," p2, p99 = np.percentile(img, (2, 99.9))\n"," img = exposure.rescale_intensity(img, in_range=(p2, p99))\n","\n"," img = 255 * img # Now scale by 255\n"," img = img.astype(np.uint8)\n"," cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n","\n","if Normalisation_QC_source == \"Adaptive Equalization\":\n"," for filename in os.listdir(Source_QC_folder):\n","\n"," img = imread(os.path.join(Source_QC_folder,filename))\n"," short_name = os.path.splitext(filename)\n","\n"," img = exposure.equalize_adapthist(img, clip_limit=0.03)\n","\n"," img = 255 * img # Now scale by 255\n"," img = img.astype(np.uint8)\n"," cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n","\n","\n","if Normalisation_QC_target == \"Adaptive Equalization\":\n","\n"," for filename in os.listdir(Target_QC_folder):\n","\n"," img = imread(os.path.join(Target_QC_folder,filename))\n"," short_name = os.path.splitext(filename)\n","\n"," p2, p99 = np.percentile(img, (2, 99.8))\n"," img = exposure.rescale_intensity(img, in_range=(p2, p99))\n","\n"," img = 255 * img # Now scale by 255\n"," img = img.astype(np.uint8)\n"," cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n","\n","if Normalisation_QC_source == \"None\":\n"," for files in os.listdir(Source_QC_folder):\n"," shutil.copyfile(Source_QC_folder+\"/\"+files, testA_Folder+\"/\"+files)\n","\n","if Normalisation_QC_target == \"None\":\n"," for files in os.listdir(Target_QC_folder):\n"," shutil.copyfile(Target_QC_folder+\"/\"+files, testB_Folder+\"/\"+files)\n","\n","\n","#Here we create a merged folder containing only imageA\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n","random_choice = random.choice(os.listdir(Source_QC_folder))\n","x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = int(min(Image_Y, Image_X))\n","\n","if not patch_size_QC % 256 == 0:\n"," patch_size_QC = ((int(patch_size_QC / 256)) * 256)\n"," print (\" Your image dimensions are not divisible by 256; therefore your images have now been resized to:\",patch_size_QC)\n","\n","if patch_size_QC < 256:\n"," patch_size_QC = 256\n","\n","Nb_Checkpoint = len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))\n","\n","print(Nb_Checkpoint)\n","\n","## Initiate lists\n","\n","Checkpoint_list = []\n","Average_ssim_score_list = []\n","Average_lpips_score_list = []\n","\n","for j in range(1, len(glob.glob(os.path.join(full_QC_model_path, '*G.pth')))+1):\n"," checkpoints = j*5\n","\n"," if checkpoints == Nb_Checkpoint*5:\n"," checkpoints = \"latest\"\n","\n"," print(\"The checkpoint currently analysed is =\"+str(checkpoints))\n","\n"," 
Checkpoint_list.append(checkpoints)\n","\n"," # Create a quality control/Prediction Folder\n","\n"," QC_prediction_results = QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)\n","\n"," if os.path.exists(QC_prediction_results):\n"," shutil.rmtree(QC_prediction_results)\n","\n"," os.makedirs(QC_prediction_results)\n","\n","#---------------------------- Predictions are performed here ----------------------\n"," os.chdir(\"/content\")\n"," !python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$QC_model_name\" --model pix2pix --epoch $checkpoints --no_dropout --preprocess scale_width --load_size $patch_size_QC --crop_size $patch_size_QC --results_dir \"$QC_prediction_results\" --checkpoints_dir \"$QC_model_path\" --direction AtoB --num_test $Nb_files_Data_folder\n","#-----------------------------------------------------------------------------------\n","\n","#Here we need to move the data again and remove all the unnecessary folders\n","\n"," Checkpoint_name = \"test_\"+str(checkpoints)\n","\n"," QC_results_images = QC_prediction_results+\"/\"+QC_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n"," QC_results_images_files = os.listdir(QC_results_images)\n","\n"," for f in QC_results_images_files: \n"," shutil.copyfile(QC_results_images+\"/\"+f, QC_prediction_results+\"/\"+f)\n","\n"," os.chdir(\"/content\") \n","\n"," #Here we clean up the extra files\n"," shutil.rmtree(QC_prediction_results+\"/\"+QC_model_name)\n","\n"," #-------------------------------- QC for RGB ------------------------------------\n"," if Image_type == \"RGB\":\n","# List images in Source_QC_folder\n","# This will find the image dimension of a randomly choosen image in Source_QC_folder \n"," random_choice = random.choice(os.listdir(Source_QC_folder))\n"," x = imageio.imread(Source_QC_folder+\"/\"+random_choice)\n","\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, multichannel=True)\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\" , \"Prediction v. GT lpips\", \"Input v. 
GT lpips\"])\n"," \n"," \n"," # Initiate list\n"," ssim_score_list = []\n"," lpips_score_list = [] \n","\n","\n"," # Let's loop through the provided dataset in the QC folders\n","\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," \n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," \n"," test_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\"))\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\"))\n"," \n"," \n"," # -------------------------------- Prediction --------------------------------\n"," \n"," test_prediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n"," \n"," #--------------------------- Here we normalise using histograms matching--------------------------------\n"," test_prediction_matched = match_histograms(test_prediction, test_GT, multichannel=True)\n"," test_source_matched = match_histograms(test_source, test_GT, multichannel=True)\n"," \n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT, test_prediction_matched)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT, test_source_matched)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n","\n"," # -------------------------------- Pearson correlation coefficient --------------------------------\n","\n","\n","\n","\n","\n"," # -------------------------------- Calculate the perceptual difference metrics map and save them --------------------------------\n"," if Do_lpips_analysis:\n","\n"," lpips_GTvsPrediction = perceptual_diff(test_GT, test_prediction, 'alex', True)\n"," lpips_GTvsPrediction_image = lpips_GTvsPrediction[0,0,...].data.cpu().numpy()\n"," lpips_GTvsPrediction_score= lpips_GTvsPrediction.mean().data.numpy()\n"," lpips_score_list.append(lpips_GTvsPrediction_score)\n","\n","\n"," lpips_GTvsSource = perceptual_diff(test_GT, test_source, 'alex', True)\n"," lpips_GTvsSource_image = lpips_GTvsSource[0,0,...].data.cpu().numpy()\n"," lpips_GTvsSource_score= lpips_GTvsSource.mean().data.numpy()\n","\n","\n"," #lpips_GTvsPrediction_image_8bit = (lpips_GTvsPrediction_image* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsPrediction_\"+shortname_no_PNG+'.tif',lpips_GTvsPrediction_image)\n","\n"," #lpips_GTvsSource_image_8bit = (lpips_GTvsSource_image* 255).astype(\"uint8\")\n"," 
io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsInput_\"+shortname_no_PNG+'.tif',lpips_GTvsSource_image)\n"," else:\n"," lpips_GTvsPrediction_score = 0\n"," lpips_score_list.append(lpips_GTvsPrediction_score)\n","\n"," lpips_GTvsSource_score = 0\n","\n","\n"," \n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource), str(lpips_GTvsPrediction_score),str(lpips_GTvsSource_score)])\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n"," Average_lpips_checkpoint = Average(lpips_score_list)\n"," Average_lpips_score_list.append(Average_lpips_checkpoint)\n","\n","#------------------------------------------- QC for Grayscale ----------------------------------------------\n","\n"," if Image_type == \"Grayscale\":\n"," def ssim(img1, img2):\n"," return structural_similarity(img1,img2,data_range=1.,full=True, gaussian_weights=True, use_sample_covariance=False, sigma=1.5)\n","\n","\n"," def normalize(x, pmin=3, pmax=99.8, axis=None, clip=False, eps=1e-20, dtype=np.float32):\n","\n","\n"," mi = np.percentile(x,pmin,axis=axis,keepdims=True)\n"," ma = np.percentile(x,pmax,axis=axis,keepdims=True)\n"," return normalize_mi_ma(x, mi, ma, clip=clip, eps=eps, dtype=dtype)\n","\n","\n"," def normalize_mi_ma(x, mi, ma, clip=False, eps=1e-20, dtype=np.float32):#dtype=np.float32\n"," \n"," if dtype is not None:\n"," x = x.astype(dtype,copy=False)\n"," mi = dtype(mi) if np.isscalar(mi) else mi.astype(dtype,copy=False)\n"," ma = dtype(ma) if np.isscalar(ma) else ma.astype(dtype,copy=False)\n"," eps = dtype(eps)\n","\n"," try:\n"," import numexpr\n"," x = numexpr.evaluate(\"(x - mi) / ( ma - mi + eps )\")\n"," except ImportError:\n"," x = (x - mi) / ( ma - mi + eps )\n","\n"," if clip:\n"," x = np.clip(x,0,1)\n","\n"," return x\n","\n"," def norm_minmse(gt, x, normalize_gt=True):\n"," \n"," if normalize_gt:\n"," gt = normalize(gt, 0.1, 99.9, clip=False).astype(np.float32, copy = False)\n"," x = x.astype(np.float32, copy=False) - np.mean(x)\n"," #x = x - np.mean(x)\n"," gt = gt.astype(np.float32, copy=False) - np.mean(gt)\n"," #gt = gt - np.mean(gt)\n"," scale = np.cov(x.flatten(), gt.flatten())[0, 1] / np.var(x.flatten())\n"," return gt, scale * x\n","\n","\n","# Open and create the csv file that will contain all the QC metrics\n"," with open(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", \"w\", newline='') as file:\n"," writer = csv.writer(file)\n","\n"," # Write the header in the csv file\n"," writer.writerow([\"image #\",\"Prediction v. GT mSSIM\",\"Input v. GT mSSIM\", \"Prediction v. GT NRMSE\", \"Input v. GT NRMSE\", \"Prediction v. GT PSNR\", \"Input v. GT PSNR\", \"Prediction v. GT lpips\", \"Input v. 
GT lpips\"]) \n","\n"," # Initialize the lists\n"," ssim_score_list = []\n"," Pearson_correlation_coefficient_list = []\n"," lpips_score_list = []\n"," \n"," # Let's loop through the provided dataset in the QC folders\n","\n"," for i in os.listdir(Source_QC_folder):\n"," if not os.path.isdir(os.path.join(Source_QC_folder,i)):\n"," print('Running QC on: '+i)\n","\n"," shortname_no_PNG = i[:-4]\n"," # -------------------------------- Target test data (Ground truth) --------------------------------\n"," test_GT_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), shortname_no_PNG+\"_real_B.png\")) \n"," test_GT = test_GT_raw[:,:,2]\n","\n"," # -------------------------------- Source test data --------------------------------\n"," test_source_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_real_A.png\")) \n"," test_source = test_source_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and Source image\n"," test_GT_norm,test_source_norm = norm_minmse(test_GT, test_source, normalize_gt=True)\n","\n"," # -------------------------------- Prediction --------------------------------\n"," test_prediction_raw = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints),shortname_no_PNG+\"_fake_B.png\"))\n"," \n"," test_prediction = test_prediction_raw[:,:,2]\n","\n"," # Normalize the images wrt each other by minimizing the MSE between GT and prediction\n"," test_GT_norm,test_prediction_norm = norm_minmse(test_GT, test_prediction, normalize_gt=True) \n","\n","\n"," # -------------------------------- Calculate the metric maps and save them --------------------------------\n","\n"," # Calculate the SSIM maps\n"," index_SSIM_GTvsPrediction, img_SSIM_GTvsPrediction = ssim(test_GT_norm, test_prediction_norm)\n"," index_SSIM_GTvsSource, img_SSIM_GTvsSource = ssim(test_GT_norm, test_source_norm)\n","\n"," ssim_score_list.append(index_SSIM_GTvsPrediction)\n","\n"," #Save ssim_maps\n"," \n"," img_SSIM_GTvsPrediction_8bit = (img_SSIM_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsPrediction_8bit)\n"," img_SSIM_GTvsSource_8bit = (img_SSIM_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/SSIM_GTvsSource_\"+shortname_no_PNG+'.tif',img_SSIM_GTvsSource_8bit)\n"," \n"," # Calculate the Root Squared Error (RSE) maps\n"," img_RSE_GTvsPrediction = np.sqrt(np.square(test_GT_norm - test_prediction_norm))\n"," img_RSE_GTvsSource = np.sqrt(np.square(test_GT_norm - test_source_norm))\n","\n"," # Save SE maps\n"," img_RSE_GTvsPrediction_8bit = (img_RSE_GTvsPrediction* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsPrediction_\"+shortname_no_PNG+'.tif',img_RSE_GTvsPrediction_8bit)\n"," img_RSE_GTvsSource_8bit = (img_RSE_GTvsSource* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/RSE_GTvsSource_\"+shortname_no_PNG+'.tif',img_RSE_GTvsSource_8bit)\n","\n","\n"," # -------------------------------- Calculate the RSE metrics and save them --------------------------------\n","\n"," # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)\n"," 
NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))\n"," NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))\n"," \n"," # We can also measure the peak signal to noise ratio between the images\n"," PSNR_GTvsPrediction = psnr(test_GT_norm,test_prediction_norm,data_range=1.0)\n"," PSNR_GTvsSource = psnr(test_GT_norm,test_source_norm,data_range=1.0)\n","\n","\n"," \n"," # -------------------------------- Calculate the perceptual difference metrics map and save them --------------------------------\n"," if Do_lpips_analysis:\n"," lpips_GTvsPrediction = perceptual_diff(test_GT_raw, test_prediction_raw, 'alex', True)\n"," lpips_GTvsPrediction_image = lpips_GTvsPrediction[0,0,...].data.cpu().numpy()\n"," lpips_GTvsPrediction_score= lpips_GTvsPrediction.mean().data.numpy()\n"," lpips_score_list.append(lpips_GTvsPrediction_score)\n","\n"," lpips_GTvsSource = perceptual_diff(test_GT_raw, test_source_raw, 'alex', True)\n"," lpips_GTvsSource_image = lpips_GTvsSource[0,0,...].data.cpu().numpy()\n"," lpips_GTvsSource_score= lpips_GTvsSource.mean().data.numpy()\n","\n","\n"," lpips_GTvsPrediction_image_8bit = (lpips_GTvsPrediction_image* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsPrediction_\"+shortname_no_PNG+'.tif',lpips_GTvsPrediction_image_8bit)\n","\n"," lpips_GTvsSource_image_8bit = (lpips_GTvsSource_image* 255).astype(\"uint8\")\n"," io.imsave(QC_model_path+'/'+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/lpips_GTvsInput_\"+shortname_no_PNG+'.tif',lpips_GTvsSource_image_8bit)\n"," else:\n"," lpips_GTvsPrediction_score = 0\n"," lpips_score_list.append(lpips_GTvsPrediction_score)\n","\n"," lpips_GTvsSource_score = 0\n","\n","\n"," writer.writerow([i,str(index_SSIM_GTvsPrediction),str(index_SSIM_GTvsSource),str(NRMSE_GTvsPrediction),str(NRMSE_GTvsSource),str(PSNR_GTvsPrediction),str(PSNR_GTvsSource),str(lpips_GTvsPrediction_score),str(lpips_GTvsSource_score)])\n","\n","\n"," #Here we calculate the ssim average for each image in each checkpoints\n","\n"," Average_SSIM_checkpoint = Average(ssim_score_list)\n"," Average_ssim_score_list.append(Average_SSIM_checkpoint)\n","\n"," Average_lpips_checkpoint = Average(lpips_score_list)\n"," Average_lpips_score_list.append(Average_lpips_checkpoint)\n","\n","\n","# All data is now processed saved\n"," \n","\n","# -------------------------------- Display --------------------------------\n","\n","# Display the IoV vs Checkpoint plot\n","plt.figure(figsize=(20,5))\n","plt.plot(Checkpoint_list, Average_ssim_score_list, label=\"SSIM\")\n","plt.title('Checkpoints vs. SSIM')\n","plt.ylabel('SSIM')\n","plt.xlabel('Checkpoints')\n","plt.legend()\n","plt.savefig(full_QC_model_path+'/Quality Control/SSIMvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n","\n","# -------------------------------- Display --------------------------------\n","\n","if Do_lpips_analysis:\n"," # Display the lpips vs Checkpoint plot\n"," plt.figure(figsize=(20,5))\n"," plt.plot(Checkpoint_list, Average_lpips_score_list, label=\"lpips\")\n"," plt.title('Checkpoints vs. 
lpips')\n"," plt.ylabel('lpips')\n"," plt.xlabel('Checkpoints')\n"," plt.legend()\n"," plt.savefig(full_QC_model_path+'/Quality Control/lpipsvsCheckpoint_data.png',bbox_inches='tight',pad_inches=0)\n"," plt.show()\n","\n","\n","\n","# -------------------------------- Display RGB --------------------------------\n","\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","\n","if Image_type == \"RGB\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n"," lpips_GTvsPrediction = df2.loc[file, \"Prediction v. GT lpips\"]\n"," lpips_GTvsSource = df2.loc[file, \"Input v. GT lpips\"]\n","\n","#Setting up colours\n"," cmap = None\n","\n"," plt.figure(figsize=(15,15))\n","\n","# Target (Ground-truth)\n"," plt.subplot(3,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"), as_gray=False, pilmode=\"RGB\")\n"," \n"," plt.imshow(img_GT, cmap = cmap)\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(3,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"), as_gray=False, pilmode=\"RGB\")\n"," plt.imshow(img_Source, cmap = cmap)\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(3,3,3)\n"," plt.axis('off')\n","\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n","\n"," plt.imshow(img_Prediction, cmap = cmap)\n"," plt.title('Prediction',fontsize=15)\n","\n","\n","#SSIM between GT and Source\n"," plt.subplot(3,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n","#plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(3,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n","\n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n","#plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#lpips Error between GT and source\n","\n"," if Do_lpips_analysis:\n"," plt.subplot(3,3,8)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_lpips_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsInput_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imlpips_GTvsSource = plt.imshow(img_lpips_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imlpips_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Input',fontsize=15)\n"," plt.xlabel('lpips: '+str(round(lpips_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('Lpips maps',fontsize=20, rotation=0, labelpad=75)\n","\n","\n"," #lpips Error between GT and Prediction\n"," plt.subplot(3,3,9)\n"," #plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_lpips_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," imlpips_GTvsPrediction = plt.imshow(img_lpips_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imlpips_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('lpips: '+str(round(lpips_GTvsPrediction,3)),fontsize=14)\n","\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","# -------------------------------- Display Grayscale --------------------------------\n","\n","if Image_type == \"Grayscale\":\n"," random_choice_shortname_no_PNG = shortname_no_PNG\n","\n"," @interact\n"," def show_results(file=os.listdir(Source_QC_folder), checkpoints=Checkpoint_list):\n","\n"," random_choice_shortname_no_PNG = file[:-4]\n","\n"," df1 = pd.read_csv(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints)+\"/\"+\"QC_metrics_\"+QC_model_name+str(checkpoints)+\".csv\", header=0)\n"," df2 = df1.set_index(\"image #\", drop = False)\n"," index_SSIM_GTvsPrediction = df2.loc[file, \"Prediction v. GT mSSIM\"]\n"," index_SSIM_GTvsSource = df2.loc[file, \"Input v. GT mSSIM\"]\n","\n"," NRMSE_GTvsPrediction = df2.loc[file, \"Prediction v. GT NRMSE\"]\n"," NRMSE_GTvsSource = df2.loc[file, \"Input v. GT NRMSE\"]\n"," PSNR_GTvsSource = df2.loc[file, \"Input v. GT PSNR\"]\n"," PSNR_GTvsPrediction = df2.loc[file, \"Prediction v. GT PSNR\"]\n"," lpips_GTvsPrediction = df2.loc[file, \"Prediction v. GT lpips\"]\n"," lpips_GTvsSource = df2.loc[file, \"Input v. GT lpips\"]\n","\n"," plt.figure(figsize=(20,20))\n"," # Currently only displays the last computed set, from memory\n"," # Target (Ground-truth)\n"," plt.subplot(4,3,1)\n"," plt.axis('off')\n"," img_GT = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_B.png\"))\n","\n"," plt.imshow(img_GT, norm=simple_norm(img_GT, percent = 99))\n"," plt.title('Target',fontsize=15)\n","\n","# Source\n"," plt.subplot(4,3,2)\n"," plt.axis('off')\n"," img_Source = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_real_A.png\"))\n"," plt.imshow(img_Source, norm=simple_norm(img_Source, percent = 99))\n"," plt.title('Source',fontsize=15)\n","\n","#Prediction\n"," plt.subplot(4,3,3)\n"," plt.axis('off')\n"," img_Prediction = io.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), random_choice_shortname_no_PNG+\"_fake_B.png\"))\n"," plt.imshow(img_Prediction, norm=simple_norm(img_Prediction, percent = 99))\n"," plt.title('Prediction',fontsize=15)\n","\n","#Setting up colours\n"," cmap = plt.cm.CMRmap\n","\n","#SSIM between GT and Source\n"," plt.subplot(4,3,5)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_SSIM_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsSource = img_SSIM_GTvsSource / 255\n"," imSSIM_GTvsSource = plt.imshow(img_SSIM_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n"," \n"," plt.colorbar(imSSIM_GTvsSource,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. 
Source',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('SSIM maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#SSIM between GT and Prediction\n"," plt.subplot(4,3,6)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False) \n"," \n"," img_SSIM_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"SSIM_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_SSIM_GTvsPrediction = img_SSIM_GTvsPrediction / 255\n"," imSSIM_GTvsPrediction = plt.imshow(img_SSIM_GTvsPrediction, cmap = cmap, vmin=0,vmax=1)\n"," \n"," plt.colorbar(imSSIM_GTvsPrediction,fraction=0.046, pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('mSSIM: '+str(round(index_SSIM_GTvsPrediction,3)),fontsize=14)\n","\n","#Root Squared Error between GT and Source\n"," plt.subplot(4,3,8)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsSource_\"+random_choice_shortname_no_PNG+\".tif\"))\n"," img_RSE_GTvsSource = img_RSE_GTvsSource / 255\n","\n"," imRSE_GTvsSource = plt.imshow(img_RSE_GTvsSource, cmap = cmap, vmin=0, vmax = 1)\n"," plt.colorbar(imRSE_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Source',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsSource,3))+', PSNR: '+str(round(PSNR_GTvsSource,3)),fontsize=14)\n","#plt.title('Target vs. Source PSNR: '+str(round(PSNR_GTvsSource,3)))\n"," plt.ylabel('RSE maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#Root Squared Error between GT and Prediction\n"," plt.subplot(4,3,9)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_RSE_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"RSE_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_RSE_GTvsPrediction = img_RSE_GTvsPrediction / 255\n","\n"," imRSE_GTvsPrediction = plt.imshow(img_RSE_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imRSE_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. 
Prediction',fontsize=15)\n"," plt.xlabel('NRMSE: '+str(round(NRMSE_GTvsPrediction,3))+', PSNR: '+str(round(PSNR_GTvsPrediction,3)),fontsize=14)\n","\n","#lpips Error between GT and source\n","\n"," if Do_lpips_analysis:\n"," plt.subplot(4,3,11)\n","\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_lpips_GTvsSource = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsInput_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_lpips_GTvsSource = img_lpips_GTvsSource / 255\n","\n"," imlpips_GTvsSource = plt.imshow(img_lpips_GTvsSource, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imlpips_GTvsSource,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Input',fontsize=15)\n"," plt.xlabel('lpips: '+str(round(lpips_GTvsSource,3)),fontsize=14)\n"," plt.ylabel('Lpips maps',fontsize=20, rotation=0, labelpad=75)\n","\n","#lpips Error between GT and Prediction\n"," plt.subplot(4,3,12)\n","#plt.axis('off')\n"," plt.tick_params(\n"," axis='both', # changes apply to the x-axis and y-axis\n"," which='both', # both major and minor ticks are affected\n"," bottom=False, # ticks along the bottom edge are off\n"," top=False, # ticks along the top edge are off\n"," left=False, # ticks along the left edge are off\n"," right=False, # ticks along the right edge are off\n"," labelbottom=False,\n"," labelleft=False)\n","\n"," img_lpips_GTvsPrediction = imageio.imread(os.path.join(QC_model_path+\"/\"+QC_model_name+\"/Quality Control/\"+str(checkpoints), \"lpips_GTvsPrediction_\"+random_choice_shortname_no_PNG+\".tif\"))\n","\n"," img_lpips_GTvsPrediction = img_lpips_GTvsPrediction / 255\n","\n"," imlpips_GTvsPrediction = plt.imshow(img_lpips_GTvsPrediction, cmap = cmap, vmin=0, vmax=1)\n"," plt.colorbar(imlpips_GTvsPrediction,fraction=0.046,pad=0.04)\n"," plt.title('Target vs. Prediction',fontsize=15)\n"," plt.xlabel('lpips: '+str(round(lpips_GTvsPrediction,3)),fontsize=14)\n","\n"," plt.savefig(full_QC_model_path+'/Quality Control/QC_example_data.png',bbox_inches='tight',pad_inches=0)\n","\n","#Make a pdf summary of the QC results\n","\n","qc_pdf_export()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"-tJeeJjLnRkP"},"source":["# **6. Using the trained model**\n","\n","---\n","\n","In this section the unseen data is processed using the trained model (in section 4). First, your unseen images are uploaded and prepared for prediction. After that your trained model from section 4 is activated and finally saved into your Google Drive."]},{"cell_type":"markdown","metadata":{"id":"d8wuQGjoq6eN"},"source":["## **6.1. Generate prediction(s) from unseen dataset**\n","---\n","\n","The current trained model (from section 4.2) can now be used to process images. If you want to use an older model, untick the **Use_the_current_trained_model** box and enter the name and path of the model to use. 
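\n","\n","The model folder provided here is expected to contain the checkpoint files saved during training; in this pix2pix implementation the generator weights are typically stored as `*_net_G.pth` files (e.g. `5_net_G.pth`, `10_net_G.pth`, ... and `latest_net_G.pth`). As a minimal sketch (with a hypothetical path), you can list the checkpoints available in a model folder like this:\n","\n","```python\n","import glob, os\n","\n","# Hypothetical model folder saved by section 4\n","model_folder = '/content/gdrive/MyDrive/models/my_pix2pix_model'\n","print(sorted(glob.glob(os.path.join(model_folder, '*_net_G.pth'))))\n","```\n","\n","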
Predicted output images are saved in your **Result_folder** folder as PNG images.\n","\n","**`Data_folder`:** This folder should contain the images that you want to use your trained network on for processing.\n","\n","**`Result_folder`:** This folder will contain the predicted output images.\n","\n","**`checkpoint`:** Choose the checkpoint number you would like to use to perform predictions. To use the \"latest\" checkpoint, input \"latest\".\n"]},{"cell_type":"code","metadata":{"id":"y2TD5p7MZrEb","cellView":"form"},"source":["#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then play the cell to predict outputs from your unseen images.\n","import glob\n","import os.path\n","\n","latest = \"latest\"\n","\n","Data_folder = \"\" #@param {type:\"string\"}\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Image normalisation:\n","\n","Normalisation_prediction_source = \"None\" #@param [\"None\", \"Contrast stretching\", \"Adaptive Equalization\"]\n","\n","# model name and path\n","#@markdown ###Do you want to use the current trained model?\n","Use_the_current_trained_model = True #@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please provide the path to the model folder:\n","\n","Prediction_model_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###What model checkpoint would you like to use?\n","\n","checkpoint = latest#@param {type:\"raw\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","patch_size = 512#@param {type:\"number\"} # in pixels\n","\n","#Here we find the loaded model name and parent path\n","Prediction_model_name = os.path.basename(Prediction_model_folder)\n","Prediction_model_path = os.path.dirname(Prediction_model_folder)\n","\n","#here we check if we use the newly trained network or not\n","if (Use_the_current_trained_model): \n","  print(\"Using current trained network\")\n","  Prediction_model_name = model_name\n","  Prediction_model_path = model_path\n","\n","# The patch_size must be a multiple of 256 and at least 256\n","if not patch_size % 256 == 0:\n","  patch_size = ((int(patch_size / 256)) * 256)\n","  print (\"Your chosen patch_size is not divisible by 256; therefore the patch_size used is now:\",patch_size)\n","\n","if patch_size < 256:\n","  patch_size = 256\n","\n","#here we check if the model exists\n","full_Prediction_model_path = Prediction_model_path+'/'+Prediction_model_name+'/'\n","\n","if os.path.exists(full_Prediction_model_path):\n","  print(\"The \"+Prediction_model_name+\" network will be used.\")\n","else:\n","  W = '\\033[0m' # white (normal)\n","  R = '\\033[31m' # red\n","  print(R+'!! WARNING: The chosen model does not exist !!'+W)\n","  print('Please make sure you provide a valid model path and model name before proceeding further.')\n","\n",
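"# Note: one '*G.pth' generator file is saved per checkpoint, so counting these files gives the number of available checkpoints.\n","# The checks below assume checkpoints were saved every 5 epochs (the save frequency expected by this notebook) and fall back to \"latest\" when the requested checkpoint is unavailable.\n",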
"Nb_Checkpoint = len(glob.glob(os.path.join(full_Prediction_model_path, '*G.pth')))+1\n","\n","if not checkpoint == \"latest\":\n","\n","  if checkpoint < 10:\n","    checkpoint = 5\n","\n","  if not checkpoint % 5 == 0:\n","    checkpoint = ((int(checkpoint / 5)-1) * 5)\n","    print (bcolors.WARNING + \" Your chosen checkpoint is not divisible by 5; therefore the checkpoint used is now:\",checkpoint)\n"," \n","  if checkpoint == Nb_Checkpoint*5:\n","    checkpoint = \"latest\"\n","\n","  if checkpoint > Nb_Checkpoint*5:\n","    checkpoint = \"latest\"\n","\n","# Here we need to move the data to be analysed so that pix2pix can find them\n","\n","Saving_path_prediction= \"/content/\"+Prediction_model_name\n","\n","if os.path.exists(Saving_path_prediction):\n","  shutil.rmtree(Saving_path_prediction)\n","os.makedirs(Saving_path_prediction)\n","\n","imageA_folder = Saving_path_prediction+\"/A\"\n","os.makedirs(imageA_folder)\n","\n","imageB_folder = Saving_path_prediction+\"/B\"\n","os.makedirs(imageB_folder)\n","\n","imageAB_folder = Saving_path_prediction+\"/AB\"\n","os.makedirs(imageAB_folder)\n","\n","testAB_Folder = Saving_path_prediction+\"/AB/test\"\n","os.makedirs(testAB_Folder)\n","\n","testA_Folder = Saving_path_prediction+\"/A/test\"\n","os.makedirs(testA_Folder)\n"," \n","testB_Folder = Saving_path_prediction+\"/B/test\"\n","os.makedirs(testB_Folder)\n","\n","#Here we copy and normalise the data\n","\n","if Normalisation_prediction_source == \"Contrast stretching\":\n"," \n","  for filename in os.listdir(Data_folder):\n","\n","    img = imread(os.path.join(Data_folder,filename)).astype(np.float32)\n","    short_name = os.path.splitext(filename)\n","\n","    # Stretch the intensities between the 2nd and 99.9th percentiles\n","    p2, p99 = np.percentile(img, (2, 99.9))\n","    img = exposure.rescale_intensity(img, in_range=(p2, p99))\n","\n","    img = 255 * img # Now scale by 255\n","    img = img.astype(np.uint8)\n","    cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n","    cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n"," \n","if Normalisation_prediction_source == \"Adaptive Equalization\":\n","\n","  for filename in os.listdir(Data_folder):\n","\n","    img = imread(os.path.join(Data_folder,filename))\n","    short_name = os.path.splitext(filename)\n","\n","    img = exposure.equalize_adapthist(img, clip_limit=0.03)\n","\n","    img = 255 * img # Now scale by 255\n","    img = img.astype(np.uint8)\n","\n","    cv2.imwrite(testA_Folder+\"/\"+short_name[0]+\".png\", img)\n","    cv2.imwrite(testB_Folder+\"/\"+short_name[0]+\".png\", img)\n","\n","if Normalisation_prediction_source == \"None\":\n","  for files in os.listdir(Data_folder):\n","    shutil.copyfile(Data_folder+\"/\"+files, testA_Folder+\"/\"+files)\n","    shutil.copyfile(Data_folder+\"/\"+files, testB_Folder+\"/\"+files)\n"," \n","# Here we create a merged A/B image for the prediction (A and B are identical copies of the input)\n","os.chdir(\"/content\")\n","!python pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py --fold_A \"$imageA_folder\" --fold_B \"$imageB_folder\" --fold_AB \"$imageAB_folder\"\n","\n","# Here we count how many images are in the folder to be predicted and we add a few extra to be safe\n","Nb_files_Data_folder = len(os.listdir(Data_folder)) +10\n","\n","# This will find the image dimension of a randomly chosen image in Data_folder \n","random_choice = random.choice(os.listdir(Data_folder))\n","x = imageio.imread(Data_folder+\"/\"+random_choice)\n","\n","#Find image XY dimension\n","Image_Y = 
x.shape[0]\n","Image_X = x.shape[1]\n","\n","Image_min_dim = min(Image_Y, Image_X)\n","\n","\n","#-------------------------------- Perform predictions -----------------------------\n","\n","#-------------------------------- Options that can be used to perform predictions -----------------------------\n","\n","# basic parameters\n"," #('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n"," #('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n"," #('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')\n"," #('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n","\n","# model parameters\n"," #('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n"," #('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n"," #('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n"," #('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n"," #('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n"," #('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n"," #('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n"," #('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n"," #('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n"," #('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n"," #('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n"," #('--no_dropout', action='store_true', help='no dropout for the generator')\n"," \n","# dataset parameters\n"," #('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')\n"," #('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n"," #('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n"," #('--num_threads', default=4, type=int, help='# threads for loading data')\n"," #('--batch_size', type=int, default=1, help='input batch size')\n"," #('--load_size', type=int, default=286, help='scale images to this size')\n"," #('--crop_size', type=int, default=256, help='then crop to this size')\n"," #('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n"," #('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n"," #('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n"," #('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n"," \n","# additional parameters\n"," #('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n"," #('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n"," #('--verbose', action='store_true', help='if specified, print more debugging information')\n"," #('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n"," \n","\n"," #('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n"," #('--results_dir', type=str, default='./results/', help='saves results here.')\n"," #('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n"," #('--phase', type=str, default='test', help='train, val, test, etc')\n","\n","# Dropout and Batchnorm have different behaviour during training and test.\n"," #('--eval', action='store_true', help='use eval mode during test time.')\n"," #('--num_test', type=int, default=50, help='how many test images to run')\n"," # rewrite default values\n"," \n","# To avoid cropping, the load_size should be the same as crop_size\n"," #parser.set_defaults(load_size=parser.get_default('crop_size'))\n","\n","#------------------------------------------------------------------------\n","\n","\n","#---------------------------- Predictions are performed here ----------------------\n","\n","os.chdir(\"/content\")\n","\n","!python pytorch-CycleGAN-and-pix2pix/test.py --dataroot \"$imageAB_folder\" --name \"$Prediction_model_name\" --model pix2pix --no_dropout --preprocess scale_width --load_size $Image_min_dim --crop_size $patch_size --results_dir \"$Result_folder\" --checkpoints_dir \"$Prediction_model_path\" --num_test $Nb_files_Data_folder --epoch $checkpoint\n","\n","#-----------------------------------------------------------------------------------\n","\n","\n","Checkpoint_name = \"test_\"+str(checkpoint)\n","\n","\n","Prediction_results_folder = Result_folder+\"/\"+Prediction_model_name+\"/\"+Checkpoint_name+\"/images\"\n","\n","Prediction_results_images = os.listdir(Prediction_results_folder)\n","\n",
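"# Note: test.py also writes '*_real_B.png' images; because folder B was filled above with copies of the input, these are duplicates of the input and are removed so that only inputs ('*_real_A.png') and predictions ('*_fake_B.png') remain.\n",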
"for f in Prediction_results_images: \n","  if (f.endswith(\"_real_B.png\")): \n","    os.remove(Prediction_results_folder+\"/\"+f)\n","\n","\n","\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Pdnb77E15zLE"},"source":["## **6.2. Inspect the predicted output**\n","---\n","\n"]},{"cell_type":"code","metadata":{"id":"CrEBdt9T53Eh","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","import os\n","# This will display a randomly chosen dataset input and predicted output\n","random_choice = random.choice(os.listdir(Data_folder))\n","\n","\n","random_choice_no_extension = os.path.splitext(random_choice)\n","\n","\n","x = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_real_A.png\")\n","\n","\n","y = imageio.imread(Result_folder+\"/\"+Prediction_model_name+\"/test_\"+str(checkpoint)+\"/images/\"+random_choice_no_extension[0]+\"_fake_B.png\")\n","\n","f=plt.figure(figsize=(16,8))\n","plt.subplot(1,2,1)\n","plt.imshow(x, interpolation='nearest')\n","plt.title('Input')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(y, interpolation='nearest')\n","plt.title('Prediction')\n","plt.axis('off');\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"hvkd66PldsXB"},"source":["## **6.3. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and, after that, clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"HD0yZaIhUhth"},"source":["# **7. Version log**\n","---\n","**v1.13**: \n","\n","\n","\n","* Sections 1 and 2 are now swapped for better export of *requirements.txt*.\n","\n","* This version also now includes a built-in version check and the version log that you're reading now.\n","\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"UvSlTaH14s3t"},"source":["\n","#**Thank you for using pix2pix!**"]}]} \ No newline at end of file