From 6a8315626a525b0faca384e13c9ad781b5756a46 Mon Sep 17 00:00:00 2001 From: guijacquemet Date: Mon, 8 Feb 2021 14:55:43 +0200 Subject: [PATCH] v1.12 --- Colab_notebooks/Beta notebooks/DRMIME_2D_ZeroCostDL4Mic.ipynb | 1 + 1 file changed, 1 insertion(+) create mode 100644 Colab_notebooks/Beta notebooks/DRMIME_2D_ZeroCostDL4Mic.ipynb diff --git a/Colab_notebooks/Beta notebooks/DRMIME_2D_ZeroCostDL4Mic.ipynb b/Colab_notebooks/Beta notebooks/DRMIME_2D_ZeroCostDL4Mic.ipynb new file mode 100644 index 00000000..3ed99d39 --- /dev/null +++ b/Colab_notebooks/Beta notebooks/DRMIME_2D_ZeroCostDL4Mic.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"DRMIME_2D_ZeroCostDL4Mic.ipynb","provenance":[{"file_id":"1hzAI0joLETcG5sI2Qvo8AKDr0TWRKySJ","timestamp":1587653755731},{"file_id":"1QFcz4NnQv4rMwDNl7AzHajN-Ola9sUFW","timestamp":1586411847878},{"file_id":"12UDRQ7abcnXcf5FctR9IUStgCpBiQWn7","timestamp":1584466922281},{"file_id":"1zXCn3A39GI1MCnXK_g_Z-AWh9vkB0YhU","timestamp":1583244415636}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.9"}},"cells":[{"cell_type":"markdown","metadata":{"id":"IkSguVy8Xv83"},"source":["# **DRMIME (2D)**\n","\n","---\n","\n"," DRMIME is a self-supervised deep-learning method that can be used to register 2D images.\n","\n"," **This particular notebook enables self-supervised registration of 2D dataset.**\n","\n","---\n","\n","*Disclaimer*:\n","\n","This notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (/~https://github.com/HenriquesLab/DeepLearning_Collab/wiki). Jointly developed by the Jacquemet (link to https://cellmig.org/) and Henriques (https://henriqueslab.github.io/) laboratories. \n","\n","\n","While this notebook is part of the Zero-Cost Deep-Learning to Enhance Microscopy project (ZeroCostDL4Mic), this notebook structure substantially deviates from other ZeroCostDL4Mic notebooks and our template. This is because the deep learning method employed here is used to improve the image registration process. No Deep Learning models are actually saved, only the registered images. \n","\n","\n","This notebook is largely based on the following paper:\n","\n","DRMIME: Differentiable Mutual Information and Matrix Exponential for Multi-Resolution Image Registration by Abhishek Nan\n"," *et al.* published on arXiv in 2020 (https://arxiv.org/abs/2001.09865)\n","\n","And source code found in: /~https://github.com/abnan/DRMIME\n","\n","**Please also cite this original paper when using or developing this notebook.**\n"]},{"cell_type":"markdown","metadata":{"id":"jWAz2i7RdxUV"},"source":["# **How to use this notebook?**\n","\n","---\n","\n","Video describing how to use our notebooks are available on youtube:\n"," - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook\n"," - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook\n","\n","\n","---\n","###**Structure of a notebook**\n","\n","The notebook contains two types of cell: \n","\n","**Text cells** provide information and can be modified by douple-clicking the cell. You are currently reading the text cell. You can create a new text by clicking `+ Text`.\n","\n","**Code cells** contain code and the code can be modfied by selecting the cell. To execute the cell, move your cursor on the `[ ]`-mark on the left side of the cell (play button appears). Click to execute the cell. After execution is done the animation of play button stops. You can create a new coding cell by clicking `+ Code`.\n","\n","---\n","###**Table of contents, Code snippets** and **Files**\n","\n","On the top left side of the notebook you find three tabs which contain from top to bottom:\n","\n","*Table of contents* = contains structure of the notebook. Click the content to move quickly between sections.\n","\n","*Code snippets* = contain examples how to code certain tasks. You can ignore this when using this notebook.\n","\n","*Files* = contain all available files. After mounting your google drive (see section 1.) you will find your files and folders here. \n","\n","**Remember that all uploaded files are purged after changing the runtime.** All files saved in Google Drive will remain. You do not need to use the Mount Drive-button; your Google Drive is connected in section 1.2.\n","\n","**Note:** The \"sample data\" in \"Files\" contains default files. Do not upload anything in here!\n","\n","---\n","###**Making changes to the notebook**\n","\n","**You can make a copy** of the notebook and save it to your Google Drive. To do this click file -> save a copy in drive.\n","\n","To **edit a cell**, double click on the text. This will show you either the source code (in code cells) or the source text (in text cells).\n","You can use the `#`-mark in code cells to comment out parts of the code. This allows you to keep the original code piece in the cell as a comment."]},{"cell_type":"markdown","metadata":{"id":"gKDLkLWUd-YX"},"source":["# **0. Before getting started**\n","---\n","\n","Before you run the notebook, please ensure that you are logged into your Google account and have the training and/or data to process in your Google Drive.\n","\n","For DRMIME to train, it requires at least two images. One **`\"Fixed image\"`** (template for the registration) and one **`Moving Image`** (image to be registered). Multiple **`Moving Images`** can also be provided if you want to register them to the same **`\"Fixed image\"`**. If you provide several **`Moving Images`**, multiple DRMIME instances will run one after another. \n","\n","The registration can also be applied to other channels. If you wish to apply the registration to other channels, please provide the images in another folder and carefully check your file names. Additional channels need to have the same name as the registered images and a prefix indicating the channel number starting at \"C1_\". See the example below. \n","\n","Here is a common data structure that can work:\n","\n","* Data\n"," \n"," - **Fixed_image_folder**\n"," - img_1.tif (image used as template for the registration)\n"," - **Moving_image_folder**\n"," - img_3.tif, img_4.tif, ... (images to be registered) \n"," - **Folder_containing_additional_channels** (optional, if you want to apply the registration to other channel(s))\n"," - C1_img_3.tif, C1_img_4.tif, ...\n"," - C2_img_3.tif, C2_img_4.tif, ...\n"," - C3_img_3.tif, C3_img_4.tif, ...\n"," - **Results**\n","\n","The **Results** folder will contain the processed images and PDF reports. Your original images remain unmodified.\n","\n","---\n","\n"]},{"cell_type":"markdown","metadata":{"id":"cbTknRcviyT7"},"source":["# **1. Initialise the Colab session**\n","\n","\n","\n","\n","---\n","\n","\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"DMNHVZfHmbKb"},"source":["## **1.1. Check for GPU access**\n","---\n","\n","By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:\n","\n","Go to **Runtime -> Change the Runtime type**\n","\n","**Runtime type: Python 3** *(Python 3 is programming language in which this program is written)*\n","\n","**Accelator: GPU** *(Graphics processing unit)*\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"h5i5CS2bSmZr"},"source":["#@markdown ##Run this cell to check if you have GPU access\n","#%tensorflow_version 1.x\n","\n","\n","import tensorflow as tf\n","if tf.test.gpu_device_name()=='':\n"," print('You do not have GPU access.') \n"," print('Did you change your runtime ?') \n"," print('If the runtime setting is correct then Google did not allocate a GPU for your session')\n"," print('Expect slow performance. To access GPU try reconnecting later')\n","\n","else:\n"," print('You have GPU access')\n"," !nvidia-smi"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n3B3meGTbYVi"},"source":["## **1.2. Mount your Google Drive**\n","---\n"," To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.\n","\n"," Play the cell below to mount your Google Drive and follow the link. In the new browser window, select your drive and select 'Allow', copy the code, paste into the cell and press enter. This will give Colab access to the data on the drive. \n","\n"," Once this is done, your data are available in the **Files** tab on the top left of notebook."]},{"cell_type":"code","metadata":{"cellView":"form","id":"01Djr8v-5pPk"},"source":["#@markdown ##Play the cell to connect your Google Drive to Colab\n","\n","#@markdown * Click on the URL. \n","\n","#@markdown * Sign in your Google Account. \n","\n","#@markdown * Copy the authorization code. \n","\n","#@markdown * Enter the authorization code. \n","\n","#@markdown * Click on \"Files\" site on the right. Refresh the site. Your Google Drive folder should now be available here as \"drive\". \n","\n","# mount user's Google Drive to Google Colab.\n","from google.colab import drive\n","drive.mount('/content/gdrive')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"n4yWFoJNnoin"},"source":["# **2. Install DRMIME and dependencies**\n","---"]},{"cell_type":"code","metadata":{"id":"fq21zJVFNASx","cellView":"form"},"source":["Notebook_version = ['1.12']\n","\n","\n","\n","#@markdown ##Install DRMIME and dependencies\n","\n","\n","# Here we install DRMIME and other required packages\n","\n","!pip install wget\n","\n","from skimage import io\n","import numpy as np\n","import math\n","import matplotlib.pyplot as plt\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","from torch.autograd import Variable\n","import torch.optim as optim\n","from skimage.transform import pyramid_gaussian\n","from skimage.filters import gaussian\n","from skimage.filters import threshold_otsu\n","from skimage.filters import sobel\n","from skimage.color import rgb2gray\n","from skimage import feature\n","from torch.autograd import Function\n","import cv2\n","from IPython.display import clear_output\n","import pandas as pd\n","from skimage.io import imsave\n","\n","\n","\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","\n","\n","\n","# ------- Common variable to all ZeroCostDL4Mic notebooks -------\n","\n","import urllib\n","import os, random\n","import shutil \n","import zipfile\n","from tifffile import imread, imsave\n","import time\n","import sys\n","import wget\n","from pathlib import Path\n","import pandas as pd\n","import csv\n","from glob import glob\n","from scipy import signal\n","from scipy import ndimage\n","from skimage import io\n","from sklearn.linear_model import LinearRegression\n","from skimage.util import img_as_uint\n","import matplotlib as mpl\n","from skimage.metrics import structural_similarity\n","from skimage.metrics import peak_signal_noise_ratio as psnr\n","from astropy.visualization import simple_norm\n","from skimage import img_as_float32\n","\n","# Colors for the warning messages\n","class bcolors:\n"," WARNING = '\\033[31m'\n","W = '\\033[0m' # white (normal)\n","R = '\\033[31m' # red\n","\n","#Disable some of the tensorflow warnings\n","import warnings\n","warnings.filterwarnings(\"ignore\")\n","\n","\n","# Check if this is the latest version of the notebook\n","Latest_notebook_version = pd.read_csv(\"https://raw.githubusercontent.com/HenriquesLab/ZeroCostDL4Mic/master/Colab_notebooks/Latest_ZeroCostDL4Mic_Release.csv\")\n","\n","if Notebook_version == list(Latest_notebook_version.columns):\n"," print(\"This notebook is up-to-date.\")\n","\n","if not Notebook_version == list(Latest_notebook_version.columns):\n"," print(bcolors.WARNING +\"A new version of this notebook has been released. We recommend that you download it at /~https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki\")\n","\n","!pip freeze > requirements.txt\n","\n","#Create a pdf document with training summary, not yet implemented\n","\n","def pdf_export(trained = False, augmentation = False, pretrained_model = False):\n"," # save FPDF() class into a \n"," # variable pdf \n"," #from datetime import datetime\n","\n"," class MyFPDF(FPDF, HTMLMixin):\n"," pass\n","\n"," pdf = MyFPDF()\n"," pdf.add_page()\n"," pdf.set_right_margin(-1)\n"," pdf.set_font(\"Arial\", size = 11, style='B') \n","\n"," Network = 'CARE 2D'\n"," day = datetime.now()\n"," datetime_str = str(day)[0:10]\n","\n"," Header = 'Training report for '+Network+' model ('+model_name+')\\nDate: '+datetime_str\n"," pdf.multi_cell(180, 5, txt = Header, align = 'L') \n","\n"," # add another cell \n"," if trained:\n"," training_time = \"Training time: \"+str(hour)+ \"hour(s) \"+str(mins)+\"min(s) \"+str(round(sec))+\"sec(s)\"\n"," pdf.cell(190, 5, txt = training_time, ln = 1, align='L')\n"," pdf.ln(1)\n","\n"," Header_2 = 'Information for your materials and methods:'\n"," pdf.cell(190, 5, txt=Header_2, ln=1, align='L')\n","\n"," all_packages = ''\n"," for requirement in freeze(local_only=True):\n"," all_packages = all_packages+requirement+', '\n"," #print(all_packages)\n","\n"," #Main Packages\n"," main_packages = ''\n"," version_numbers = []\n"," for name in ['tensorflow','numpy','Keras','csbdeep']:\n"," find_name=all_packages.find(name)\n"," main_packages = main_packages+all_packages[find_name:all_packages.find(',',find_name)]+', '\n"," #Version numbers only here:\n"," version_numbers.append(all_packages[find_name+len(name)+2:all_packages.find(',',find_name)])\n","\n"," cuda_version = subprocess.run('nvcc --version',stdout=subprocess.PIPE, shell=True)\n"," cuda_version = cuda_version.stdout.decode('utf-8')\n"," cuda_version = cuda_version[cuda_version.find(', V')+3:-1]\n"," gpu_name = subprocess.run('nvidia-smi',stdout=subprocess.PIPE, shell=True)\n"," gpu_name = gpu_name.stdout.decode('utf-8')\n"," gpu_name = gpu_name[gpu_name.find('Tesla'):gpu_name.find('Tesla')+10]\n"," #print(cuda_version[cuda_version.find(', V')+3:-1])\n"," #print(gpu_name)\n","\n"," shape = io.imread(Training_source+'/'+os.listdir(Training_source)[1]).shape\n"," dataset_size = len(os.listdir(Training_source))\n","\n"," text = 'The '+Network+' model was trained from scratch for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," if pretrained_model:\n"," text = 'The '+Network+' model was trained for '+str(number_of_epochs)+' epochs on '+str(dataset_size*number_of_patches)+' paired image patches (image dimensions: '+str(shape)+', patch size: ('+str(patch_size)+','+str(patch_size)+')) with a batch size of '+str(batch_size)+' and a '+config.train_loss+' loss function, using the '+Network+' ZeroCostDL4Mic notebook (v '+Notebook_version[0]+') (von Chamier & Laine et al., 2020). The model was re-trained from a pretrained model. Key python packages used include tensorflow (v '+version_numbers[0]+'), Keras (v '+version_numbers[2]+'), csbdeep (v '+version_numbers[3]+'), numpy (v '+version_numbers[1]+'), cuda (v '+cuda_version+'). The training was accelerated using a '+gpu_name+'GPU.'\n","\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," pdf.multi_cell(190, 5, txt = text, align='L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(28, 5, txt='Augmentation: ', ln=0)\n"," pdf.set_font('')\n"," if augmentation:\n"," aug_text = 'The dataset was augmented by a factor of '+str(Multiply_dataset_by)+' by'\n"," if rotate_270_degrees != 0 or rotate_90_degrees != 0:\n"," aug_text = aug_text+'\\n- rotation'\n"," if flip_left_right != 0 or flip_top_bottom != 0:\n"," aug_text = aug_text+'\\n- flipping'\n"," if random_zoom_magnification != 0:\n"," aug_text = aug_text+'\\n- random zoom magnification'\n"," if random_distortion != 0:\n"," aug_text = aug_text+'\\n- random distortion'\n"," if image_shear != 0:\n"," aug_text = aug_text+'\\n- image shearing'\n"," if skew_image != 0:\n"," aug_text = aug_text+'\\n- image skewing'\n"," else:\n"," aug_text = 'No augmentation was used for training.'\n"," pdf.multi_cell(190, 5, txt=aug_text, align='L')\n"," pdf.set_font('Arial', size = 11, style = 'B')\n"," pdf.ln(1)\n"," pdf.cell(180, 5, txt = 'Parameters', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font_size(10.)\n"," if Use_Default_Advanced_Parameters:\n"," pdf.cell(200, 5, txt='Default Advanced Parameters were enabled')\n"," pdf.cell(200, 5, txt='The following parameters were used for training:')\n"," pdf.ln(1)\n"," html = \"\"\" \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
ParameterValue
number_of_epochs{0}
patch_size{1}
number_of_patches{2}
batch_size{3}
number_of_steps{4}
percentage_validation{5}
initial_learning_rate{6}
\n"," \"\"\".format(number_of_epochs,str(patch_size)+'x'+str(patch_size),number_of_patches,batch_size,number_of_steps,percentage_validation,initial_learning_rate)\n"," pdf.write_html(html)\n","\n"," #pdf.multi_cell(190, 5, txt = text_2, align='L')\n"," pdf.set_font(\"Arial\", size = 11, style='B')\n"," pdf.ln(1)\n"," pdf.cell(190, 5, txt = 'Training Dataset', align='L', ln=1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(29, 5, txt= 'Training_source:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_source, align = 'L')\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(27, 5, txt= 'Training_target:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = Training_target, align = 'L')\n"," #pdf.cell(190, 5, txt=aug_text, align='L', ln=1)\n"," pdf.ln(1)\n"," pdf.set_font('')\n"," pdf.set_font('Arial', size = 10, style = 'B')\n"," pdf.cell(22, 5, txt= 'Model Path:', align = 'L', ln=0)\n"," pdf.set_font('')\n"," pdf.multi_cell(170, 5, txt = model_path+'/'+model_name, align = 'L')\n"," pdf.ln(1)\n"," pdf.cell(60, 5, txt = 'Example Training pair', ln=1)\n"," pdf.ln(1)\n"," exp_size = io.imread('/content/TrainingDataExample_CARE2D.png').shape\n"," pdf.image('/content/TrainingDataExample_CARE2D.png', x = 11, y = None, w = round(exp_size[1]/8), h = round(exp_size[0]/8))\n"," pdf.ln(1)\n"," ref_1 = 'References:\\n - ZeroCostDL4Mic: von Chamier, Lucas & Laine, Romain, et al. \"ZeroCostDL4Mic: an open platform to simplify access and use of Deep-Learning in Microscopy.\" BioRxiv (2020).'\n"," pdf.multi_cell(190, 5, txt = ref_1, align='L')\n"," ref_2 = '- CARE: Weigert, Martin, et al. \"Content-aware image restoration: pushing the limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.'\n"," pdf.multi_cell(190, 5, txt = ref_2, align='L')\n"," if augmentation:\n"," ref_3 = '- Augmentor: Bloice, Marcus D., Christof Stocker, and Andreas Holzinger. \"Augmentor: an image augmentation library for machine learning.\" arXiv preprint arXiv:1708.04680 (2017).'\n"," pdf.multi_cell(190, 5, txt = ref_3, align='L')\n"," pdf.ln(3)\n"," reminder = 'Important:\\nRemember to perform the quality control step on all newly trained models\\nPlease consider depositing your training dataset on Zenodo'\n"," pdf.set_font('Arial', size = 11, style='B')\n"," pdf.multi_cell(190, 5, txt=reminder, align='C')\n","\n"," pdf.output(model_path+'/'+model_name+'/'+model_name+\"_training_report.pdf\")\n","\n","\n","\n","\n","print(\"Libraries installed\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HLYcZR9gMv42"},"source":["# **3. Select your parameters and paths**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Kbn9_JdqnNnK"},"source":["## **3.1. Setting main training parameters**\n","---\n"," "]},{"cell_type":"markdown","metadata":{"id":"CB6acvUFtWqd"},"source":[" **Paths for training, predictions and results**\n","These is the path to your folders containing the image you want to register. To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.\n","\n","**`Fixed_image_folder`:** This is the folder containing your \"Fixed image\".\n","\n","**`Moving_image_folder`:** This is the folder containing your \"Moving Image(s)\".\n","\n","**`Result_folder`:** This is the folder where your results will be saved.\n","\n","\n","**Training Parameters**\n","\n","**`model_name`:** Choose a name for your model.\n","\n","**`number_of_iteration`:** Input how many iteration (rounds) the network will be trained. Preliminary results can already be observed after a 200 iterations, but a full training should run for 500-1000 iterations. **Default value: 500**\n","\n","**`Registration_mode`:** Choose which registration method you would like to use.\n","\n","**Additional channels**\n","\n"," This option enable you to apply the registration to other images (for instance other channels). Place these images in the **`Additional_channels_folder`**. Additional channels need to have the same name as the images you want to register (found in **`Moving_image_folder`**) and a prefix indicating the channel number starting at \"C1_\".\n","\n"," \n","**Advanced Parameters - experienced users only**\n","\n","**`n_neurons`:** Number of neurons (elementary constituents) that will assemble your model. **Default value: 100**.\n","\n","**`mine_initial_learning_rate`:** Input the initial value to be used as learning rate for MINE. **Default value: 0.001**\n","**`homography_net_vL_initial_learning_rate`:** Input the initial value to be used as learning rate for homography_net_vL. **Default value: 0.001**\n","\n","**`homography_net_v1_initial_learning_rate`:** Input the initial value to be used as learning rate for homography_net_v1. **Default value: 0.0001**\n"]},{"cell_type":"code","metadata":{"id":"ewpNJ_I0Mv47","cellView":"form"},"source":["\n","#@markdown ###Path to the Fixed and Moving image folders: \n","Fixed_image_folder = \"\" #@param {type:\"string\"}\n","\n","\n","import os.path\n","from os import path\n","\n","if path.isfile(Fixed_image_folder):\n"," I = imread(Fixed_image_folder).astype(np.float32) # fixed image\n","\n","if path.isdir(Fixed_image_folder):\n"," Fixed_image = os.listdir(Fixed_image_folder)\n"," I = imread(Fixed_image_folder+\"/\"+Fixed_image[0]).astype(np.float32) # fixed image\n","\n","\n","Moving_image_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ### Provide the path to the folder where the predictions are to be saved\n","Result_folder = \"\" #@param {type:\"string\"}\n","\n","\n","#@markdown ###Training Parameters\n","model_name = \"\" #@param {type:\"string\"}\n","\n","number_of_iteration = 500#@param {type:\"number\"}\n","\n","Registration_mode = \"Affine\" #@param [\"Affine\", \"Perspective\"]\n","\n","\n","#@markdown ###Do you want to apply the registration to other channel(s)?\n","Apply_registration_to_other_channels = False#@param {type:\"boolean\"}\n","\n","Additional_channels_folder = \"\" #@param {type:\"string\"}\n","\n","#@markdown ###Advanced Parameters\n","\n","Use_Default_Advanced_Parameters = True#@param {type:\"boolean\"}\n","\n","#@markdown ###If not, please input:\n","\n","n_neurons = 100 #@param {type:\"number\"}\n","mine_initial_learning_rate = 0.001 #@param {type:\"number\"}\n","homography_net_vL_initial_learning_rate = 0.001 #@param {type:\"number\"}\n","homography_net_v1_initial_learning_rate = 0.0001 #@param {type:\"number\"}\n","\n","if (Use_Default_Advanced_Parameters): \n"," print(\"Default advanced parameters enabled\") \n"," n_neurons = 100\n"," mine_initial_learning_rate = 0.001\n"," homography_net_vL_initial_learning_rate = 0.001\n"," homography_net_v1_initial_learning_rate = 0.0001\n","\n","\n","#failsafe for downscale could be useful \n","#to be added\n","\n","\n","#Load a random moving image to visualise and test the settings\n","random_choice = random.choice(os.listdir(Moving_image_folder))\n","J = imread(Moving_image_folder+\"/\"+random_choice).astype(np.float32)\n","\n","# Check if additional channel(s) need to be registered and if so how many\n","\n","print(str(len(os.listdir(Moving_image_folder)))+\" image(s) will be registered.\")\n","\n","if Apply_registration_to_other_channels:\n","\n"," other_channel_images = os.listdir(Additional_channels_folder)\n"," Number_of_other_channels = len(other_channel_images)/len(os.listdir(Moving_image_folder))\n","\n"," if Number_of_other_channels.is_integer():\n"," print(\"The registration(s) will be propagated to \"+str(Number_of_other_channels)+\" other channel(s)\")\n"," else:\n"," print(bcolors.WARNING +\"!! WARNING: Incorrect number of images in Folder_containing_additional_channels\"+W)\n","\n","#here we check that no model with the same name already exist, if so print a warning\n","if os.path.exists(Result_folder+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: \"+model_name+\" already exists and will be deleted in the following cell !!\")\n"," print(bcolors.WARNING +\"To continue training \"+model_name+\", choose a new model_name here, and load \"+model_name+\" in section 3.3\"+W)\n"," \n","\n","print(\"Example of two images to be registered\")\n","\n","#Here we display one image\n","f=plt.figure(figsize=(10,10))\n","plt.subplot(1,2,1)\n","plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest')\n","\n","\n","plt.title('Fixed image')\n","plt.axis('off');\n","\n","plt.subplot(1,2,2)\n","plt.imshow(J, norm=simple_norm(J, percent = 99), interpolation='nearest')\n","plt.title('Moving image')\n","plt.axis('off');\n","plt.savefig('/content/TrainingDataExample_DRMIME2D.png',bbox_inches='tight',pad_inches=0)\n","plt.show()\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QpKgUER3y9tn"},"source":["## **3.2. Choose and test the image pre-processing settings**\n","---\n"," DRMIME makes use of multi-resolution image pyramids to perform registration. Unlike a conventional method where computation starts at the highest level of the image pyramid and gradually proceeds to the lower levels, DRMIME simultaneously use all the levels in gradient descent-based optimization using automatic differentiation. Here, you can choose the parameters that define the multi-resolution image pyramids that will be used.\n","\n","**`nb_images_pyramid`:** Choose the number of images to use to assemble the pyramid. **Default value: 10**.\n","\n","**`Level_downscaling`:** Choose the level of downscaling that will be used to create the images of the pyramid **Default value: 1.8**.\n","\n","**`sampling`:** amount of sampling used for the perspective registration. **Default value: 0.1**.\n","\n"]},{"cell_type":"code","metadata":{"cellView":"form","id":"MoNXLwG6yd76"},"source":["\n","#@markdown ##Image pre-processing settings\n","\n","nb_images_pyramid = 10#@param {type:\"number\"} # where registration starts (at the coarsest resolution)\n","\n","L = nb_images_pyramid\n","\n","Level_downscaling = 1.8#@param {type:\"number\"}\n","\n","downscale = Level_downscaling\n","\n","sampling = 0.1#@param {type:\"number\"} # 10% sampling used only for perspective registration\n","\n","\n","ifplot=True\n","if np.ndim(I) == 3:\n"," nChannel=I.shape[2]\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n","elif np.ndim(I) == 2:\n"," nChannel=1\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n","else:\n"," print(\"Unknown rank for an image\")\n","\n","\n","# Control the display\n","width=5\n","height=5\n","rows = int(L/5)+1\n","cols = 5\n","axes=[]\n","fig=plt.figure(figsize=(16,16))\n","\n","if Registration_mode == \"Affine\":\n","\n"," print(\"Affine registration selected\")\n","\n","# create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n","\n"," if nChannel>1:\n"," I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_) \n"," \n"," axes.append( fig.add_subplot(rows, cols, s+1) )\n"," subplot_title=(str(s))\n"," axes[-1].set_title(subplot_title) \n"," plt.imshow(edges_grayscale)\n"," plt.axis('off');\n","\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," fig.tight_layout()\n","\n"," plt.show()\n","\n","\n","if Registration_mode == \"Perspective\":\n","\n"," print(\"Perspective registration selected\")\n","\n","# create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," \n"," if nChannel>1:\n"," I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 10),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_) \n"," \n"," axes.append( fig.add_subplot(rows, cols, s+1) )\n"," subplot_title=(str(s))\n"," axes[-1].set_title(subplot_title) \n"," plt.imshow(edges_grayscale)\n"," plt.axis('off');\n","\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," fig.tight_layout()\n","\n"," plt.show()\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"keIQhCmOMv5S"},"source":["# **4. Train the network**\n","---"]},{"cell_type":"markdown","metadata":{"id":"Ovu0ESxivcxx"},"source":["## **4.1. Prepare for training**\n","---\n","Here, we use the information from 3. to load the correct dependencies."]},{"cell_type":"code","metadata":{"id":"t4QTv4vQvbnS","cellView":"form"},"source":["#@markdown ##Load the dependencies required for training\n","\n","print(\"--------------------------------------------------\")\n","\n","# Remove the model name folder if exists\n","\n","if os.path.exists(Result_folder+'/'+model_name):\n"," print(bcolors.WARNING +\"!! WARNING: Model folder already exists and has been removed !!\"+W)\n"," shutil.rmtree(Result_folder+'/'+model_name)\n","os.makedirs(Result_folder+'/'+model_name)\n","\n","\n","\n","if Registration_mode == \"Affine\":\n","\n"," class HomographyNet(nn.Module):\n"," def __init__(self):\n"," super(HomographyNet, self).__init__()\n"," # affine transform basis matrices\n","\n"," self.B = torch.zeros(6,3,3).to(device)\n"," self.B[0,0,2] = 1.0\n"," self.B[1,1,2] = 1.0\n"," self.B[2,0,1] = 1.0\n"," self.B[3,1,0] = 1.0\n"," self.B[4,0,0], self.B[4,1,1] = 1.0, -1.0\n"," self.B[5,1,1], self.B[5,2,2] = -1.0, 1.0\n","\n"," self.v1 = torch.nn.Parameter(torch.zeros(6,1,1).to(device), requires_grad=True)\n"," self.vL = torch.nn.Parameter(torch.zeros(6,1,1).to(device), requires_grad=True)\n","\n"," def forward(self, s):\n"," C = torch.sum(self.B*self.vL,0)\n"," if s==0:\n"," C += torch.sum(self.B*self.v1,0)\n"," A = torch.eye(3).to(device)\n"," H = A\n"," for i in torch.arange(1,10):\n"," A = torch.mm(A/i,C)\n"," H = H + A\n"," return H\n","\n"," class MINE(nn.Module): #https://arxiv.org/abs/1801.04062\n"," def __init__(self):\n"," super(MINE, self).__init__()\n"," self.fc1 = nn.Linear(2*nChannel, n_neurons)\n"," self.fc2 = nn.Linear(n_neurons, n_neurons)\n"," self.fc3 = nn.Linear(n_neurons, 1)\n"," self.bsize = 1 # 1 may be sufficient\n","\n"," def forward(self, x, ind):\n"," x = x.view(x.size()[0]*x.size()[1],x.size()[2])\n"," MI_lb=0.0\n"," for i in range(self.bsize):\n"," ind_perm = ind[torch.randperm(len(ind))]\n"," z1 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(x[ind,:])))))\n"," z2 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(torch.cat((x[ind,0:nChannel],x[ind_perm,nChannel:2*nChannel]),1))))))\n"," MI_lb += torch.mean(z1) - torch.log(torch.mean(torch.exp(z2)))\n","\n"," return MI_lb/self.bsize\n","\n"," def AffineTransform(I, H, xv, yv):\n"," # apply affine transform\n"," xvt = (xv*H[0,0]+yv*H[0,1]+H[0,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," yvt = (xv*H[1,0]+yv*H[1,1]+H[1,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," J = F.grid_sample(I,torch.stack([xvt,yvt],2).unsqueeze(0)).squeeze()\n"," return J\n","\n","\n"," def multi_resolution_loss():\n"," loss=0.0\n"," for s in np.arange(L-1,-1,-1):\n"," if nChannel>1:\n"," Jw_ = AffineTransform(J_lst[s].unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.cat([I_lst[s],Jw_],0).permute(1,2,0),ind_lst[s])\n"," loss = loss - (1./L)*mi\n"," else:\n"," Jw_ = AffineTransform(J_lst[s].unsqueeze(0).unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.stack([I_lst[s],Jw_],2),ind_lst[s])\n"," loss = loss - (1./L)*mi\n","\n"," return loss\n","\n","\n","\n","if Registration_mode == \"Perspective\":\n","\n"," class HomographyNet(nn.Module):\n"," def __init__(self):\n"," super(HomographyNet, self).__init__()\n"," # affine transform basis matrices\n","\n"," self.B = torch.zeros(8,3,3).to(device)\n"," self.B[0,0,2] = 1.0\n"," self.B[1,1,2] = 1.0\n"," self.B[2,0,1] = 1.0\n"," self.B[3,1,0] = 1.0\n"," self.B[4,0,0], self.B[4,1,1] = 1.0, -1.0\n"," self.B[5,1,1], self.B[5,2,2] = -1.0, 1.0\n"," self.B[6,2,0] = 1.0\n"," self.B[7,2,1] = 1.0\n","\n"," self.v1 = torch.nn.Parameter(torch.zeros(8,1,1).to(device), requires_grad=True)\n"," self.vL = torch.nn.Parameter(torch.zeros(8,1,1).to(device), requires_grad=True)\n","\n"," def forward(self, s):\n"," C = torch.sum(self.B*self.vL,0)\n"," if s==0:\n"," C += torch.sum(self.B*self.v1,0)\n"," A = torch.eye(3).to(device)\n"," H = A\n"," for i in torch.arange(1,10):\n"," A = torch.mm(A/i,C)\n"," H = H + A\n"," return H\n","\n","\n"," class MINE(nn.Module): #https://arxiv.org/abs/1801.04062\n"," def __init__(self):\n"," super(MINE, self).__init__()\n"," self.fc1 = nn.Linear(2*nChannel, n_neurons)\n"," self.fc2 = nn.Linear(n_neurons, n_neurons)\n"," self.fc3 = nn.Linear(n_neurons, 1)\n"," self.bsize = 1 # 1 may be sufficient\n","\n"," def forward(self, x, ind):\n"," x = x.view(x.size()[0]*x.size()[1],x.size()[2])\n"," MI_lb=0.0\n"," for i in range(self.bsize):\n"," ind_perm = ind[torch.randperm(len(ind))]\n"," z1 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(x[ind,:])))))\n"," z2 = self.fc3(F.relu(self.fc2(F.relu(self.fc1(torch.cat((x[ind,0:nChannel],x[ind_perm,nChannel:2*nChannel]),1))))))\n"," MI_lb += torch.mean(z1) - torch.log(torch.mean(torch.exp(z2)))\n","\n"," return MI_lb/self.bsize\n","\n","\n"," def PerspectiveTransform(I, H, xv, yv):\n"," # apply homography\n"," xvt = (xv*H[0,0]+yv*H[0,1]+H[0,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," yvt = (xv*H[1,0]+yv*H[1,1]+H[1,2])/(xv*H[2,0]+yv*H[2,1]+H[2,2])\n"," J = F.grid_sample(I,torch.stack([xvt,yvt],2).unsqueeze(0)).squeeze()\n"," return J\n","\n","\n"," def multi_resolution_loss():\n"," loss=0.0\n"," for s in np.arange(L-1,-1,-1):\n"," if nChannel>1:\n"," Jw_ = PerspectiveTransform(J_lst[s].unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.cat([I_lst[s],Jw_],0).permute(1,2,0),ind_lst[s])\n"," loss = loss - (1./L)*mi\n"," else:\n"," Jw_ = PerspectiveTransform(J_lst[s].unsqueeze(0).unsqueeze(0), homography_net(s), xy_lst[s][:,:,0], xy_lst[s][:,:,1]).squeeze()\n"," mi = mine_net(torch.stack([I_lst[s],Jw_],2),ind_lst[s])\n"," loss = loss - (1./L)*mi\n","\n"," return loss\n","\n"," def histogram_mutual_information(image1, image2):\n"," hgram, x_edges, y_edges = np.histogram2d(image1.ravel(), image2.ravel(), bins=100)\n"," pxy = hgram / float(np.sum(hgram))\n"," px = np.sum(pxy, axis=1)\n"," py = np.sum(pxy, axis=0)\n"," px_py = px[:, None] * py[None, :]\n"," nzs = pxy > 0\n"," return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))\n","\n","\n","print(\"Done\")\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0Dfn8ZsEMv5d"},"source":["## **4.2. Start Trainning**\n","---\n","When playing the cell below you should see updates after each iterations (round). A new network will be trained for each image that need to be registered.\n","\n","* **CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for datamining). Training time must be less than 12 hours! If training takes longer than 12 hours, please decrease the number of epochs or number of patches. Another way circumvent this is to save the parameters of the model after training and start training again from this point.\n","\n"]},{"cell_type":"code","metadata":{"id":"fisJmA13Mv5e","scrolled":true,"cellView":"form"},"source":["#@markdown ##Start training and the registration process\n","\n","start = time.time()\n","\n","loop_number = 1\n","\n","\n","\n","if Registration_mode == \"Affine\":\n","\n"," print(\"Affine registration.....\")\n","\n"," for image in os.listdir(Moving_image_folder):\n","\n"," if path.isfile(Fixed_image_folder):\n"," I = imread(Fixed_image_folder).astype(np.float32) # fixed image\n","\n"," if path.isdir(Fixed_image_folder):\n"," Fixed_image = os.listdir(Fixed_image_folder)\n"," I = imread(Fixed_image_folder+\"/\"+Fixed_image[0]).astype(np.float32) # fixed image\n","\n"," J = imread(Moving_image_folder+\"/\"+image).astype(np.float32)\n","\n"," # Here we generate the pyramidal images\n"," ifplot=True\n"," if np.ndim(I) == 3:\n"," nChannel=I.shape[2]\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," elif np.ndim(I) == 2:\n"," nChannel=1\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," else:\n"," print(\"Unknown rank for an image\")\n","\n","\n"," # create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n","\n","\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n","\n"," if nChannel>1:\n"," I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 30),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.nonzero(torch.tensor(edges_grayscale).view(h_*w_)).squeeze().to(device)[:1000000]\n"," ind_lst.append(ind_)\n","\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," homography_net = HomographyNet().to(device)\n"," mine_net = MINE().to(device)\n","\n"," optimizer = optim.Adam([{'params': mine_net.parameters(), 'lr': 1e-3},\n"," {'params': homography_net.vL, 'lr': 5e-3},\n"," {'params': homography_net.v1, 'lr': 1e-4}], amsgrad=True)\n"," mi_list = []\n"," for itr in range(number_of_iteration):\n"," optimizer.zero_grad()\n"," loss = multi_resolution_loss()\n"," mi_list.append(-loss.item())\n"," loss.backward()\n"," optimizer.step()\n"," clear_output(wait=True)\n"," plt.plot(mi_list)\n"," plt.xlabel('Iteration number')\n"," plt.ylabel('MI')\n"," plt.title(image+\". Image registration \"+str(loop_number)+\" out of \"+str(len(os.listdir(Moving_image_folder)))+\".\")\n"," plt.show()\n","\n"," I_t = torch.tensor(I).to(device) # without Gaussian\n"," J_t = torch.tensor(J).to(device) # without Gaussian\n"," H = homography_net(0)\n"," if nChannel>1:\n"," J_w = AffineTransform(J_t.permute(2,0,1).unsqueeze(0), H, xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze().permute(1,2,0)\n"," else:\n"," J_w = AffineTransform(J_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n","\n"," #Apply registration to other channels\n","\n"," if Apply_registration_to_other_channels:\n","\n"," for n_channel in range(1, int(Number_of_other_channels)+1):\n","\n"," channel = imread(Additional_channels_folder+\"/C\"+str(n_channel)+\"_\"+image).astype(np.float32)\n"," channel_t = torch.tensor(channel).to(device)\n"," channel_w = AffineTransform(channel_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n"," channel_registered = channel_w.cpu().data.numpy()\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+\"C\"+str(n_channel)+\"_\"+image+\"_\"+Registration_mode+\"_registered.tif\", channel_registered)\n"," \n","# Export results to numpy array\n"," registered = J_w.cpu().data.numpy()\n","# Save results\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+image+\"_\"+Registration_mode+\"_registered.tif\", registered)\n","\n"," loop_number = loop_number + 1\n","\n"," print(\"Your images have been registered and saved in your result_folder\")\n","\n","\n","#Perspective registration\n","\n","if Registration_mode == \"Perspective\":\n","\n"," print(\"Perspective registration.....\")\n","\n"," for image in os.listdir(Moving_image_folder):\n","\n"," if path.isfile(Fixed_image_folder):\n"," I = imread(Fixed_image_folder).astype(np.float32) # fixed image\n","\n"," if path.isdir(Fixed_image_folder):\n"," Fixed_image = os.listdir(Fixed_image_folder)\n"," I = imread(Fixed_image).astype(np.float32) # fixed image\n","\n"," J = imread(Moving_image_folder+\"/\"+image).astype(np.float32)\n","\n"," # Here we generate the pyramidal images\n"," ifplot=True\n"," if np.ndim(I) == 3:\n"," nChannel=I.shape[2]\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=True), downscale=downscale, multichannel=True))\n"," elif np.ndim(I) == 2:\n"," nChannel=1\n"," pyramid_I = tuple(pyramid_gaussian(gaussian(I, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," pyramid_J = tuple(pyramid_gaussian(gaussian(J, sigma=1, multichannel=False), downscale=downscale, multichannel=False))\n"," else:\n"," print(\"Unknown rank for an image\")\n","\n","\n"," # create a list of necessary objects you will need and commit to GPU\n"," I_lst,J_lst,h_lst,w_lst,xy_lst,ind_lst=[],[],[],[],[],[]\n"," for s in range(L):\n"," I_ = torch.tensor(cv2.normalize(pyramid_I[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n"," J_ = torch.tensor(cv2.normalize(pyramid_J[s].astype(np.float32), None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)).to(device)\n","\n"," if nChannel>1:\n"," I_lst.append(I_.permute(2,0,1))\n"," J_lst.append(J_.permute(2,0,1))\n"," h_, w_ = I_lst[s].shape[1], I_lst[s].shape[2]\n","\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_)\n"," else:\n"," I_lst.append(I_)\n"," J_lst.append(J_)\n"," h_, w_ = I_lst[s].shape[0], I_lst[s].shape[1]\n","\n"," edges_grayscale = cv2.dilate(cv2.Canny(cv2.GaussianBlur(rgb2gray(pyramid_I[s]),(21,21),0).astype(np.uint8), 0, 10),\n"," np.ones((5,5),np.uint8),\n"," iterations = 1)\n"," ind_ = torch.randperm(int(h_*w_*sampling))\n"," ind_lst.append(ind_)\n"," h_lst.append(h_)\n"," w_lst.append(w_)\n","\n"," y_, x_ = torch.meshgrid([torch.arange(0,h_).float().to(device), torch.arange(0,w_).float().to(device)])\n"," y_, x_ = 2.0*y_/(h_-1) - 1.0, 2.0*x_/(w_-1) - 1.0\n"," xy_ = torch.stack([x_,y_],2)\n"," xy_lst.append(xy_)\n","\n"," homography_net = HomographyNet().to(device)\n"," mine_net = MINE().to(device)\n","\n"," optimizer = optim.Adam([{'params': mine_net.parameters(), 'lr': 1e-3},\n"," {'params': homography_net.vL, 'lr': 1e-3},\n"," {'params': homography_net.v1, 'lr': 1e-4}], amsgrad=True)\n"," mi_list = []\n"," for itr in range(number_of_iteration):\n"," optimizer.zero_grad()\n"," loss = multi_resolution_loss()\n"," mi_list.append(-loss.item())\n"," loss.backward()\n"," optimizer.step()\n"," clear_output(wait=True)\n"," plt.plot(mi_list)\n"," plt.xlabel('Iteration number')\n"," plt.ylabel('MI')\n"," plt.title(image+\". Image registration \"+str(loop_number)+\" out of \"+str(len(os.listdir(Moving_image_folder)))+\".\")\n"," plt.show()\n","\n"," I_t = torch.tensor(I).to(device) # without Gaussian\n"," J_t = torch.tensor(J).to(device) # without Gaussian\n"," H = homography_net(0)\n"," if nChannel>1:\n"," J_w = PerspectiveTransform(J_t.permute(2,0,1).unsqueeze(0), H, xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze().permute(1,2,0)\n"," else:\n"," J_w = PerspectiveTransform(J_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n","\n"," #Apply registration to other channels\n","\n"," if Apply_registration_to_other_channels:\n","\n"," for n_channel in range(1, int(Number_of_other_channels)+1):\n","\n"," channel = imread(Additional_channels_folder+\"/C\"+str(n_channel)+\"_\"+image).astype(np.float32)\n"," channel_t = torch.tensor(channel).to(device)\n"," channel_w = PerspectiveTransform(channel_t.unsqueeze(0).unsqueeze(0), H , xy_lst[0][:,:,0], xy_lst[0][:,:,1]).squeeze()\n"," channel_registered = channel_w.cpu().data.numpy()\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+\"C\"+str(n_channel)+\"_\"+image+\"_Perspective_registered.tif\", channel_registered) \n","\n","\n","# Export results to numpy array\n"," registered = J_w.cpu().data.numpy()\n","# Save results\n"," io.imsave(Result_folder+'/'+model_name+\"/\"+image+\"_Perspective_registered.tif\", registered)\n","\n"," loop_number = loop_number + 1\n","\n"," print(\"Your images have been registered and saved in your result_folder\")\n","\n","\n","# PDF export missing \n","\n","#pdf_export(trained = True, augmentation = Use_Data_augmentation, pretrained_model = Use_pretrained_model)\n","\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PfTw_pQUUAqB"},"source":["## **4.3. Assess the registration**\n","---\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"SrArBvqwYvc9","cellView":"form"},"source":["# @markdown ##Run this cell to display a randomly chosen input and its corresponding predicted output.\n","\n","# For sliders and dropdown menu and progress bar\n","from ipywidgets import interact\n","import ipywidgets as widgets\n","\n","print('--------------------------------------------------------------')\n","@interact\n","def show_QC_results(file = os.listdir(Moving_image_folder)):\n","\n"," moving_image = imread(Moving_image_folder+\"/\"+file).astype(np.float32)\n"," \n"," registered_image = imread(Result_folder+\"/\"+model_name+\"/\"+file+\"_\"+Registration_mode+\"_registered.tif\").astype(np.float32)\n","\n","#Here we display one image\n","\n"," f=plt.figure(figsize=(20,20))\n"," plt.subplot(1,5,1)\n"," plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest')\n"," plt.title('Fixed image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,2)\n"," plt.imshow(moving_image, norm=simple_norm(moving_image, percent = 99), interpolation='nearest')\n"," plt.title('Moving image')\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,3)\n"," plt.imshow(registered_image, norm=simple_norm(registered_image, percent = 99), interpolation='nearest')\n"," plt.title(\"Registered image\")\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,4)\n"," plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest', cmap=\"Greens\")\n"," plt.imshow(moving_image, norm=simple_norm(moving_image, percent = 99), interpolation='nearest', cmap=\"Oranges\", alpha=0.5)\n"," plt.title(\"Fixed and moving images\")\n"," plt.axis('off');\n","\n"," plt.subplot(1,5,5)\n"," plt.imshow(I, norm=simple_norm(I, percent = 99), interpolation='nearest', cmap=\"Greens\")\n"," plt.imshow(registered_image, norm=simple_norm(registered_image, percent = 99), interpolation='nearest', cmap=\"Oranges\", alpha=0.5)\n"," plt.title(\"Fixed and Registered images\")\n"," plt.axis('off');\n","\n"," plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"wgO7Ok1PBFQj"},"source":["## **4.4. Download your predictions**\n","---\n","\n","**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name."]},{"cell_type":"markdown","metadata":{"id":"nlyPYwZu4VVS"},"source":["#**Thank you for using DRMIME 2D!**"]}]} \ No newline at end of file