{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "420b6034-fe36-483d-b676-a6eeea2e80c0",
   "metadata": {},
   "source": [
    "## Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58b3e6f3-7476-47bc-ade6-d5e020c78055",
   "metadata": {},
   "source": [
    "**Linn Abraham**\n",
    "* Final year PhD student - Thesis on the Application of Deep Learning to different problems in Astronomy\n",
    "* Here in IUCAA as part of the ISRO RESPOND project - **Solar Flares: Physics and Forecasting for better understanding of Space Weather**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6a7da9c-fdfa-40fe-8daf-2313891cbaef",
   "metadata": {},
   "source": [
    "* [Automated Detection of Galactic Rings from Sloan Digital Sky Survey Images](https://doi.org/10.3847/1538-4357/ad856d)\n",
    "    * Co-authors: Sheelu Abraham, Ajit Kembhavi, Ninan Sajeeth Philip, Sudhanshu Barwaye and others\n",
    "* Source detection for H1 galaxies from radio data\n",
    "    * Co-authors: Kshitij H Thorat, Arun K. Aniyan and others\n",
    "* Solar active region classification and interpretability using Deep Learning\n",
    "    * Co-authors: Durgesh Tripathi, Vishal Upendran and others "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef56eda3-1c3d-4d5e-9cb8-c07e5b55c960",
   "metadata": {},
   "source": [
    "Contact:\n",
    "* E-mail: linn.official@gmail.com\n",
    "* Linkedin - [www.linkedin.com/in/linn-abraham/](https://www.linkedin.com/in/linn-abraham/)\n",
    "* Github - [github.com/linnabraham](https://github.com/linnabraham)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd31bc4e-19ac-4dd2-aa12-a56f55d88cd7",
   "metadata": {},
   "source": [
    "## Galaxy Ring Detection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b65c0ce-bafb-429f-8cc6-e4f711cc53e9",
   "metadata": {},
   "source": [
    "### What is a ring in a galaxy?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a163a342-6059-4bdc-bf81-711cd212e168",
   "metadata": {},
   "source": [
    "Hubble Tuning Fork"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "803eab13-99a1-4c0c-a579-cc8b49076c85",
   "metadata": {},
   "source": [
    "![hubble_fork](hubbles-fork.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88943f8c-26b3-473b-8403-5af31899c600",
   "metadata": {},
   "source": [
    "![galaxy types](galaxy_types.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2ec5086-94c3-4058-96a9-0d18bacdb5f2",
   "metadata": {},
   "source": [
    "* The understanding that rings are important morphological features for studying galaxy formation and evolution came later\n",
    "* de Vaucouleurs introduced the idea of a 3-dimensional classification volume, of which Hubble’s tuning fork is a sort of cross section parallel to one axis."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e2d214e-2420-44c0-b2f1-dc75522bd3c0",
   "metadata": {},
   "source": [
    "![hubble3d](hubble_3d.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81c5cde1-90cd-445c-aa2b-f68ec1a33918",
   "metadata": {},
   "source": [
    "### SDSS Image cutouts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f73e92f-aca5-468b-8f00-87293d3d2f41",
   "metadata": {},
   "source": [
    "![Training Images](sdss_train_rings.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1e10c5e-3098-4874-8d7b-695528cc2cc4",
   "metadata": {},
   "source": [
    "## Resources"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9eceee41-daf9-442a-a976-a8f09b1afd14",
   "metadata": {},
   "source": [
    "* ChatGPT, stackoverflow, blogs and video tutorials are helpful.\n",
    "* Return back to books to check your understanding and build up basic knowledge\n",
    "* Read papers for understanding the latest architectures and techniques."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f345aff-2099-4791-b9af-c8b353559bc9",
   "metadata": {},
   "source": [
    "#### Beginner level"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0892f8b6-f96d-4e61-836a-c7f985117f35",
   "metadata": {},
   "source": [
    "![marsland](books1.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c37077a4-9e15-490c-9b90-4d430355230e",
   "metadata": {},
   "source": [
    "#### More advanced"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "940f7d14-295a-4e89-9bf3-cc377379bc82",
   "metadata": {},
   "source": [
    "![books2](books2.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d76f4222-24c0-4c9d-a8a0-3fa00936d846",
   "metadata": {},
   "source": [
    "* [Compilation of resources Topic Wise](https://github.com/linnabraham/ml-tutorials/blob/master/Resources.md)\n",
    "* [Beginner's guide to machine learning - My Blog Entry](https://machinelearningmaniac.blogspot.com/2024/05/beginners-guide-to-machine-learning.html)\n",
    "* [3Blue1Brown Playlist on Neural Networks](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcfc8d71-9409-4d43-95cb-2f75db18f6a0",
   "metadata": {},
   "source": [
    "## Tips and Tricks\n",
    "Powered by own life experience"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0631f3e1-bab1-47ce-b245-485a7a9a2330",
   "metadata": {},
   "source": [
    "### Choosing frameworks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1be5acc4-db45-468b-aba4-370cb5175c8d",
   "metadata": {},
   "source": [
    "* Tensorflow and PyTorch are the two major frameworks or platforms for deep learning in Python\n",
    "* Learning to write your own training loops is immensely useful\n",
    "* model.fit in tensorflow vs writing your own custom training loop in PyTorch\n",
    "* tf.data.Datasets modelled after PyTorch dataset\n",
    "* GradientTape api and custom training loops exist in tensorflow as well\n",
    "* (Asitang's Lecture in IUCAA 2023)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81f467df-b635-4ac2-8e3b-193a8f972052",
   "metadata": {},
   "source": [
    "### Word about IDEs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17027e00-1799-4dbd-a4dc-fbb0bd44a78b",
   "metadata": {},
   "source": [
    "* Notebooks have two good uses\n",
    "    * Inline images (also Mardown support) which makes them good to present your code to yourself or others\n",
    "    * Keep variables in memory until you quit the program\n",
    "    * Reasons not to use it - [I Don't like Jupyter Notebooks](https://youtu.be/7jiPeIFXb6U)\n",
    "* Be comfortable with ViM + script workflow:\n",
    "    * Save figures to disk and view them through SSH fs or something\n",
    "    * Pylint\n",
    "    * PDB for debugging\n",
    "* Once you know these things also try out VSCode and others"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cda9b90-6dd8-4b11-87c9-1f7e12d85ce9",
   "metadata": {},
   "source": [
    "### Virtual environment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dec78376-6034-4bd0-b579-590138c61616",
   "metadata": {},
   "source": [
    "* Always work in a virtual environment\n",
    "* Conda and Pip environments serve different use cases - https://machinelearningmaniac.blogspot.com/2024/05/a-guide-to-using-conda-for-managing.html\n",
    "* Version control your environment using environment.yml or requirements.txt for reproducibility"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "621549d2-853c-4edd-9e9e-b27fcf46667b",
   "metadata": {},
   "source": [
    "### General Tips"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fca19442-e011-4333-a32a-0eaa1653a14d",
   "metadata": {},
   "source": [
    "* Never re-invent the wheel; always resort to well developed libraries\n",
    "    * Scikit-learn, scipy, pandas, scikit-image, OpenCV \n",
    "* For things that don't exist elsewhere put up your custom utility libraries in github and install them into your own environment using Pip\n",
    "* Writing classes might help at some point when you data structure and functionalities get complex and interdependent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4595dc8f-9d54-46fe-9c11-46c7eec6eeb3",
   "metadata": {},
   "source": [
    "### Code version control"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abb874a7-6051-47a1-89c6-1e299ca661ef",
   "metadata": {},
   "source": [
    "* Any ML person needs to use git - the early the better\n",
    "* On the command line first\n",
    "* Not just to avoid directories named project_v1, project_v2, project_v3\n",
    "* Very versatile tool\n",
    "* [Missing semester lecture](https://youtu.be/2sjqTHE0zok)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b0c7779-9888-4d00-a17c-974c3f1b1aa7",
   "metadata": {},
   "source": [
    "### Clone git repo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1163ae14-3856-486a-a0d5-bbc946c242a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'galactic-rings'...\n",
      "remote: Enumerating objects: 587, done.\u001b[K\n",
      "remote: Counting objects: 100% (65/65), done.\u001b[K\n",
      "remote: Compressing objects: 100% (47/47), done.\u001b[K\n",
      "remote: Total 587 (delta 21), reused 50 (delta 10), pack-reused 522 (from 1)\u001b[K\n",
      "Receiving objects: 100% (587/587), 9.48 MiB | 5.31 MiB/s, done.\n",
      "Resolving deltas: 100% (328/328), done.\n"
     ]
    }
   ],
   "source": [
    "!git clone -b handson https://github.com/linnabraham/galactic-rings.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dd6e9703-6ca0-4cfb-b336-3f231e69249b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/linn/2024/dec/aiml-handson/galactic-rings\n"
     ]
    }
   ],
   "source": [
    "%cd galactic-rings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f9f46dcf-e6a3-41fa-86ec-861b00c249d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".\n",
      "├── alexnet_utils\n",
      "├── create_pred_catalogue.py\n",
      "├── data\n",
      "├── environment.yml\n",
      "├── evaluate_alexnet.py\n",
      "├── figures\n",
      "├── helpers\n",
      "├── notebooks\n",
      "├── optional-deps.yml\n",
      "├── plot_hists.py\n",
      "├── predict_alexnet.py\n",
      "├── predict_single.py\n",
      "├── README.md\n",
      "├── requirements.txt\n",
      "├── separate_images.py\n",
      "├── torch\n",
      "├── train_alexnet_kfold.py\n",
      "├── train_alexnet.py\n",
      "└── train_lenet.py\n",
      "\n",
      "6 directories, 13 files\n"
     ]
    }
   ],
   "source": [
    "!tree -L 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e2841a5-7248-42a4-99d9-11864290470d",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe834649-fc86-416f-9c29-672c07f7edf9",
   "metadata": {},
   "source": [
    "### Domain knowledge is important"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c22ccf5a-8f4e-4bd9-9f06-54c566daae32",
   "metadata": {},
   "source": [
    "* Pan-STARRS vs SDSS\n",
    "* Pan-STARRS was much less suited for galaxy morphology analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6764e17-cdfd-4e19-b3ba-70f95d332de4",
   "metadata": {},
   "source": [
    "![ps](panstarrs_collage.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83a3ad88-af54-4b53-b480-03d2c562d10e",
   "metadata": {},
   "source": [
    "### Quantity of your data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a645321-8992-49d9-8d31-ef755bb25fde",
   "metadata": {},
   "source": [
    "* Used the two largest catalogs of galaxy morphology available\n",
    "* Buta (2017) for Rings\n",
    "* Nair. et. al. (2010) for Non-Rings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcb50c6b-317f-468b-a776-0079654cf0f5",
   "metadata": {},
   "source": [
    "#### Sample of rings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff5722de-969c-4fa0-bfe8-67bf8f8f798b",
   "metadata": {},
   "source": [
    "![train_rings](sdss_train_rings.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e340be7-b250-424b-b07d-1acb142543bc",
   "metadata": {},
   "source": [
    "#### Rings: Buta (2017)\n",
    "* 3962 Rings\n",
    "* Citizen science + Expert analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c51db562-d7e7-4499-b579-e1fa9cbb4172",
   "metadata": {},
   "source": [
    "#### Non-Rings: Nair et. al. (2010) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab04c00b-5c45-4c92-abf8-e2cd91f974ad",
   "metadata": {},
   "source": [
    "* 14,034 galaxies with both rings and non-rings\n",
    "* Ring classifications - nuclear, inner, outer, pseudo-outer (R1/R2), and collisional.\n",
    "* Brighter than 16 mag\n",
    "* Between 0.01 and 0.1 redshift"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0519d146-ff3a-4fcf-b2ca-d37506963959",
   "metadata": {},
   "source": [
    "### Avoid any possible bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70a9dba1-89b7-431f-ae71-a14d8b426321",
   "metadata": {},
   "source": [
    "* Apply the same selection criteria to Rings as in Non-Rings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b9210a8-6417-4029-b8fa-8606d6c8a376",
   "metadata": {},
   "source": [
    "### Quality of your data matters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a8c10d3-0a3d-48f6-bf20-5c3be7bec345",
   "metadata": {},
   "source": [
    "Looked through the large sample by eye to remove the \"bad\" data\n",
    "* Multiple galaxies\n",
    "* Outliers\n",
    "* Ambigous types\n",
    "\n",
    "Finally left with 1122 rings and 10,639 non-rings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6449c260-3993-464a-af0d-4cdb0bab1f09",
   "metadata": {},
   "source": [
    "### Data version control"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88179b53-cd26-418c-a7e9-345e5f69eead",
   "metadata": {},
   "source": [
    "* Use [DVC](https://dvc.org/)\n",
    "* md5sum hashes for a directory structure + git"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b49e097-351a-49cc-ab24-8c3dd8712cd6",
   "metadata": {},
   "source": [
    "### Script for downloading galaxy image cutouts in bulk"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2a7c542-9a1e-43e1-9a39-692185bdabfe",
   "metadata": {},
   "source": [
    "* Multi-threading (not processing) for I/O or network intensive processes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9edb23f7-34df-429d-ac32-1fb80e89c726",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading...\n",
      "From (original): https://drive.google.com/uc?id=1RlPl3WD4JDz5N-kx5g00YveLtZ43vPw-\n",
      "From (redirected): https://drive.google.com/uc?id=1RlPl3WD4JDz5N-kx5g00YveLtZ43vPw-&confirm=t&uuid=c4a3b2e1-99ee-41b4-ab31-ef057d60ecfc\n",
      "To: /home/linn/2024/dec/aiml-handson/galactic-rings/galaxies.tar.gz\n",
      "100%|██████████████████████████████████████| 77.6M/77.6M [00:06<00:00, 12.5MB/s]\n"
     ]
    }
   ],
   "source": [
    "!gdown --fuzzy \"https://drive.google.com/file/d/1RlPl3WD4JDz5N-kx5g00YveLtZ43vPw-/view?usp=drive_link\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "35268e0d-cbcd-4d6a-9f93-011df696474d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "galaxies.tar.gz: OK\n"
     ]
    }
   ],
   "source": [
    "!echo \"0f456e955b0b5312aec8d2dd6186218c  galaxies.tar.gz\" | md5sum -c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9162f14e-c5d7-4ffd-a7db-79c9de50daee",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!tar xvzf galaxies.tar.gz -C data/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6879f158-1cfb-4461-8dd4-0c2da193a4ae",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2158e591-a5cd-433b-9fde-2056b10c2f3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install wandb\n",
    "!pip install gdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dcbeca7-37fa-4a2d-8c5b-cb32c96029eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install tensorflow"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c62cf196-0e53-4414-973b-abd8500321b6",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ca7afdeb-08bd-454e-8c6f-21247494fba6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-09 14:48:33.496568: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2025-01-09 14:48:35.498549: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.callbacks import ModelCheckpoint, Callback\n",
    "from alexnet_utils.params import parser, print_arguments\n",
    "from alexnet_utils.alexnet import AlexNet\n",
    "import wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6f933e3b-e755-46b3-a314-33c90e179e8d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.12.1\n"
     ]
    }
   ],
   "source": [
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6cec405a-e353-44d4-8bb0-7eb92cc836c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.15.11\n"
     ]
    }
   ],
   "source": [
    "print(wandb.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6ef00404-47dd-497b-821a-919df6a01f88",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n"
     ]
    }
   ],
   "source": [
    "print(tf.config.list_physical_devices('GPU'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51e251d2-0ed1-427a-8abb-debc353c4131",
   "metadata": {},
   "source": [
    "## Data augmentations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6113e98-a3dd-4dfa-bac8-2164829d9eff",
   "metadata": {},
   "source": [
    "* Rescale data to be between (0,1) make it easier for the network to do all the math.\n",
    "* Custom augmentations\n",
    "  * Rotation\n",
    "  * Flip\n",
    "  * Brightness\n",
    "  * Contrast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b80d4e92-c79a-474d-abda-838edd1bb841",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_choice(x, size, seed, axis=0, unique=True):\n",
    "    dim_x = tf.cast(tf.shape(x)[axis], tf.int64)\n",
    "    indices = tf.range(0, dim_x, dtype=tf.int64)\n",
    "    sample_index = tf.random.shuffle(indices,seed=seed)[:size]\n",
    "    sample = tf.gather(x, sample_index, axis=axis)\n",
    "\n",
    "    return sample, sample_index\n",
    "\n",
    "def random_int_rot_img(inputs,seed):\n",
    "    angles = tf.constant([1, 2, 3, 4])\n",
    "    # Make a new seed.\n",
    "    #new_seed = tf.random.experimental.stateless_split((seed,seed), num=1)[0, :]\n",
    "    angle = random_choice(angles,1,seed=seed)[0][0]\n",
    "    inputs = tf.image.rot90(inputs, k=angle)\n",
    "\n",
    "    return inputs\n",
    "\n",
    "def rescale(image, label):\n",
    "    image = tf.cast(image, tf.float32)\n",
    "    image = (image / 255.0)\n",
    "\n",
    "    return image, label\n",
    "\n",
    "# define custom augmentations\n",
    "def augment_custom(images, labels, augmentation_types, seed):\n",
    "    images, labels = rescale(images, labels)\n",
    "    # Make a new seed.\n",
    "    #new_seed = tf.random.experimental.stateless_split((seed,seed), num=1)[0, :]\n",
    "    new_seed = seed\n",
    "    if 'rotation' in augmentation_types:\n",
    "        images = random_int_rot_img(images,seed=seed)\n",
    "    if 'flip' in augmentation_types:\n",
    "        images = tf.image.random_flip_left_right(images, seed=new_seed)\n",
    "        images = tf.image.random_flip_up_down(images, seed=new_seed)\n",
    "    if 'brightness' in augmentation_types:\n",
    "        images = tf.image.random_brightness(images, max_delta=0.2, seed=new_seed)\n",
    "    if 'contrast' in augmentation_types:\n",
    "        images = tf.image.random_contrast(images, lower=0.2, upper=0.5, seed=new_seed)\n",
    "\n",
    "    return (images, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cac1070-a99d-4f0e-8ea4-366adb5a7df4",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5648f3a-aacb-47c9-a76e-59dd6334d251",
   "metadata": {},
   "source": [
    "### Define argparse arguments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "583cdd74-010a-48df-b1a1-8cee9e12aaae",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreTrueAction(option_strings=['-retrain', '--retrain'], dest='retrain', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='Whether to continue previous training', metavar=None)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "parser.add_argument('-images', '--images', required=True, help=\"path containing images of two classes\")\n",
    "parser.add_argument('-epochs', '--epochs', required=True, type=int, default=50, help=\"num epochs\")\n",
    "parser.add_argument('-model-path', '--model-path', default=None, help=\"Filepath to save model during training and to load model from when testing\")\n",
    "parser.add_argument('-val-dir', '--val-dir', default=None, help=\"path containing validation data\")\n",
    "parser.add_argument('-retrain', '--retrain', action=\"store_true\", help=\"Whether to continue previous training\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "65240c23-7815-4d2f-8f69-b3591737a63d",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parser.parse_args(['-images', 'data/galaxies', '-epochs', '2'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "3915f923-bfa4-453a-bb48-d12acdec39c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Arguments and Data Types:\n",
      "  target_size: tuple_type - (240, 240)\n",
      "  batch_size: int - 16\n",
      "  train_frac: float - 0.8\n",
      "  random_state: int - 42\n",
      "  num_classes: int - 2\n",
      "  channels: int - 3\n",
      "  output_dir: None - output\n",
      "  augmentation_types: str - ['flip', 'rotation']\n",
      "  images: None - data/galaxies\n",
      "  epochs: int - 2\n",
      "  model_path: None - None\n",
      "  val_dir: None - None\n"
     ]
    }
   ],
   "source": [
    "print_arguments(parser, args)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06e8b9a3-9f69-4e9e-ad74-a6a086425526",
   "metadata": {},
   "source": [
    "## Network Architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "406180ab-4444-47db-b2e5-b5d537d375a6",
   "metadata": {},
   "source": [
    "* Choice of the architecture should match your data size\n",
    "* If your architecture of choice is not pre-defined in your favourite framework you (or someone else) has to implement it by hand"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b37347fb-693f-4ace-9d6f-e9d90fa2d41c",
   "metadata": {},
   "source": [
    "### Define AlexNet architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b8cf3584-5842-4e7b-a1a7-084b494806cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-09 14:49:20.192419: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38367 MB memory:  -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:37:00.0, compute capability: 8.0\n"
     ]
    }
   ],
   "source": [
    "model = AlexNet.build(width=args.target_size[0], height=args.target_size[1], \\\n",
    "                      depth=args.channels, classes=1, reg=0.0002)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f0ced6e9-1d13-4934-bf8a-1df94450ab2b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " conv2d (Conv2D)             (None, 120, 120, 96)      7296      \n",
      "                                                                 \n",
      " activation (Activation)     (None, 120, 120, 96)      0         \n",
      "                                                                 \n",
      " batch_normalization (BatchN  (None, 120, 120, 96)     384       \n",
      " ormalization)                                                   \n",
      "                                                                 \n",
      " max_pooling2d (MaxPooling2D  (None, 59, 59, 96)       0         \n",
      " )                                                               \n",
      "                                                                 \n",
      " dropout (Dropout)           (None, 59, 59, 96)        0         \n",
      "                                                                 \n",
      " conv2d_1 (Conv2D)           (None, 59, 59, 256)       614656    \n",
      "                                                                 \n",
      " activation_1 (Activation)   (None, 59, 59, 256)       0         \n",
      "                                                                 \n",
      " batch_normalization_1 (Batc  (None, 59, 59, 256)      1024      \n",
      " hNormalization)                                                 \n",
      "                                                                 \n",
      " max_pooling2d_1 (MaxPooling  (None, 29, 29, 256)      0         \n",
      " 2D)                                                             \n",
      "                                                                 \n",
      " dropout_1 (Dropout)         (None, 29, 29, 256)       0         \n",
      "                                                                 \n",
      " conv2d_2 (Conv2D)           (None, 29, 29, 384)       885120    \n",
      "                                                                 \n",
      " activation_2 (Activation)   (None, 29, 29, 384)       0         \n",
      "                                                                 \n",
      " batch_normalization_2 (Batc  (None, 29, 29, 384)      1536      \n",
      " hNormalization)                                                 \n",
      "                                                                 \n",
      " conv2d_3 (Conv2D)           (None, 29, 29, 384)       1327488   \n",
      "                                                                 \n",
      " activation_3 (Activation)   (None, 29, 29, 384)       0         \n",
      "                                                                 \n",
      " batch_normalization_3 (Batc  (None, 29, 29, 384)      1536      \n",
      " hNormalization)                                                 \n",
      "                                                                 \n",
      " conv2d_4 (Conv2D)           (None, 29, 29, 256)       884992    \n",
      "                                                                 \n",
      " activation_4 (Activation)   (None, 29, 29, 256)       0         \n",
      "                                                                 \n",
      " batch_normalization_4 (Batc  (None, 29, 29, 256)      1024      \n",
      " hNormalization)                                                 \n",
      "                                                                 \n",
      " max_pooling2d_2 (MaxPooling  (None, 14, 14, 256)      0         \n",
      " 2D)                                                             \n",
      "                                                                 \n",
      " dropout_2 (Dropout)         (None, 14, 14, 256)       0         \n",
      "                                                                 \n",
      " flatten (Flatten)           (None, 50176)             0         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 4096)              205524992 \n",
      "                                                                 \n",
      " activation_5 (Activation)   (None, 4096)              0         \n",
      "                                                                 \n",
      " batch_normalization_5 (Batc  (None, 4096)             16384     \n",
      " hNormalization)                                                 \n",
      "                                                                 \n",
      " dropout_3 (Dropout)         (None, 4096)              0         \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 4096)              16781312  \n",
      "                                                                 \n",
      " activation_6 (Activation)   (None, 4096)              0         \n",
      "                                                                 \n",
      " batch_normalization_6 (Batc  (None, 4096)             16384     \n",
      " hNormalization)                                                 \n",
      "                                                                 \n",
      " dropout_4 (Dropout)         (None, 4096)              0         \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 1)                 4097      \n",
      "                                                                 \n",
      " activation_7 (Activation)   (None, 1)                 0         \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 226,068,225\n",
      "Trainable params: 226,049,089\n",
      "Non-trainable params: 19,136\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "print(model.summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84efc9da-7b30-4e9a-a144-4f3f9310943d",
   "metadata": {},
   "source": [
    "## Define validation loss, evaluation metrics, optimizer and learning rate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ea67bdd-fd4f-4109-984c-6669459a004d",
   "metadata": {},
   "source": [
    "Image Credit: Chollet\n",
    "![skeleton](ml_skeleton.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "487f8ed4-b038-4052-a28e-89e7f51a0854",
   "metadata": {},
   "outputs": [],
   "source": [
    "classification_threshold = 0.5\n",
    "\n",
    "METRICS = [\n",
    "      tf.keras.metrics.Precision(thresholds=classification_threshold,\n",
    "                                 name='precision'),\n",
    "      tf.keras.metrics.Recall(thresholds=classification_threshold,\n",
    "                              name=\"recall\"),\n",
    "      tf.keras.metrics.AUC(num_thresholds=100, curve='PR', name='auc_pr'),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "4d7783a9-47e9-42f0-ad2d-9df1a91792af",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss=\"binary_crossentropy\", optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3)\\\n",
    "              , metrics=METRICS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b0f1bea-9a2b-4a58-9206-9aaa54a920bc",
   "metadata": {},
   "source": [
    "## Define Callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62d65eed-2448-4875-9d39-793288feae06",
   "metadata": {},
   "source": [
    "* Model checkpoint\n",
    "* Save history\n",
    "* Wandb logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5a211631-dfdb-4772-81cf-313c015ecd35",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SaveHistoryCallback(Callback):\n",
    "    def __init__(self, file_path):\n",
    "        super().__init__()\n",
    "        self.file_path = file_path\n",
    "        self.history = {'loss': [], 'val_loss': [], 'auc_pr':[], 'val_auc_pr':[], 'val_precision':[], 'val_recall':[]}\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        self.history['loss'].append(logs.get('loss'))\n",
    "        self.history['val_loss'].append(logs.get('val_loss'))\n",
    "        self.history['auc_pr'].append(logs.get('auc_pr'))\n",
    "        self.history['val_auc_pr'].append(logs.get('val_auc_pr'))\n",
    "        self.history['val_precision'].append(logs.get('val_precision'))\n",
    "        self.history['val_recall'].append(logs.get('val_recall'))\n",
    "\n",
    "        with open(self.file_path, 'w') as f:\n",
    "            json.dump(self.history, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "627eb5f6-5301-4830-8b26-7a20426a7a79",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_callbacks(run_name):\n",
    "    outdir = os.path.join(\"output\", run_name)\n",
    "    if not os.path.exists(outdir):\n",
    "        os.makedirs(outdir)\n",
    "    model_path = os.path.join(outdir,\"best_model.keras\")\n",
    "    mc = ModelCheckpoint(model_path, monitor='val_loss', \\\n",
    "        mode='min', verbose=1, save_best_only=True)\n",
    "    history_path = os.path.join(outdir,'history.json')\n",
    "    hc = SaveHistoryCallback(history_path)\n",
    "    callbacks=[mc,hc, wandb.keras.WandbMetricsLogger()]\n",
    "    return callbacks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23f58c76-5e26-4b86-8c5d-eb4bdbccaa9d",
   "metadata": {},
   "source": [
    "## Create tf.data.Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d9bfc099-40e2-4dc6-b27f-6fc191c63807",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_train_data(data_dir, val_dir, train_frac, target_size, batch_size, augmentation_types, outdir, random_state):\n",
    "  if val_dir is None:\n",
    "      train_ds, val_ds = tf.keras.utils.image_dataset_from_directory(\n",
    "        data_dir,\n",
    "        validation_split=1-train_frac,\n",
    "        subset=\"both\",\n",
    "        color_mode='rgb',\n",
    "        seed=random_state,\n",
    "        image_size=target_size,\n",
    "        batch_size=None)\n",
    "  else:\n",
    "      train_ds = tf.keras.utils.image_dataset_from_directory(\n",
    "        data_dir,\n",
    "        color_mode='rgb',\n",
    "        seed=random_state,\n",
    "        image_size=target_size,\n",
    "        batch_size=None)\n",
    "\n",
    "      val_ds = tf.keras.utils.image_dataset_from_directory(\n",
    "            val_dir,\n",
    "            color_mode='rgb',\n",
    "            seed=random_state,\n",
    "            image_size=target_size,\n",
    "            batch_size=None)\n",
    "\n",
    "  class_names = train_ds.class_names\n",
    "  print(\"Training dataset class names are :\",class_names)\n",
    "\n",
    "  AUTOTUNE = tf.data.AUTOTUNE\n",
    "\n",
    "  train_ds = (\n",
    "          train_ds\n",
    "          .shuffle(1000)\n",
    "          .map(lambda x, y: augment_custom(x, y, augmentation_types, seed=random_state), num_parallel_calls=AUTOTUNE)\n",
    "          #.cache()\n",
    "          .batch(batch_size)\n",
    "          .prefetch(buffer_size=AUTOTUNE)\n",
    "          )\n",
    "\n",
    "  val_ds = (\n",
    "          val_ds\n",
    "          .map(rescale, num_parallel_calls=AUTOTUNE)\n",
    "          #.cache()\n",
    "          .batch(batch_size)\n",
    "          .prefetch(buffer_size=AUTOTUNE)\n",
    "          )\n",
    "\n",
    "  return train_ds, val_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "0e2804a5-ca55-4cd1-a6ce-823b805d5f91",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 15229 files belonging to 2 classes.\n",
      "Using 12184 files for training.\n",
      "Using 3045 files for validation.\n",
      "Training dataset class names are : ['NonRings', 'Rings']\n"
     ]
    }
   ],
   "source": [
    "train_ds, val_ds = get_train_data(args.images, args.val_dir, args.train_frac, args.target_size, args.batch_size,\\\n",
    "                                  args.augmentation_types, args.output_dir, args.random_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "3aeddb1a-35e1-4a88-85ea-1fde093d83b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlinn-official\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.19.2 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.11"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/home/linn/2024/dec/aiml-handson/galactic-rings/wandb/run-20250109_144948-8e6g66t9</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/linn-official/aiml-handson/runs/8e6g66t9' target=\"_blank\">fragrant-shape-7</a></strong> to <a href='https://wandb.ai/linn-official/aiml-handson' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/linn-official/aiml-handson' target=\"_blank\">https://wandb.ai/linn-official/aiml-handson</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/linn-official/aiml-handson/runs/8e6g66t9' target=\"_blank\">https://wandb.ai/linn-official/aiml-handson/runs/8e6g66t9</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-09 14:50:06.142372: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [12184]\n",
      "\t [[{{node Placeholder/_0}}]]\n",
      "2025-01-09 14:50:06.143000: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [12184]\n",
      "\t [[{{node Placeholder/_4}}]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-09 14:50:10.550579: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential/dropout/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer\n",
      "2025-01-09 14:50:12.780207: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8902\n",
      "2025-01-09 14:50:16.787475: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:637] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.\n",
      "2025-01-09 14:50:17.910929: I tensorflow/compiler/xla/service/service.cc:169] XLA service 0x7fe8d837d990 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
      "2025-01-09 14:50:17.910983: I tensorflow/compiler/xla/service/service.cc:177]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB, Compute Capability 8.0\n",
      "2025-01-09 14:50:18.240301: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
      "2025-01-09 14:50:19.617198: I ./tensorflow/compiler/jit/device_compiler.h:180] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "762/762 [==============================] - ETA: 0s - loss: 3.1502 - precision: 0.0976 - recall: 0.0935 - auc_pr: 0.0915     "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-09 14:50:58.921923: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [3045]\n",
      "\t [[{{node Placeholder/_0}}]]\n",
      "2025-01-09 14:50:58.923012: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [3045]\n",
      "\t [[{{node Placeholder/_4}}]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 1: val_loss improved from inf to 2.29671, saving model to output/fragrant-shape-7/best_model.keras\n",
      "762/762 [==============================] - 63s 49ms/step - loss: 3.1502 - precision: 0.0976 - recall: 0.0935 - auc_pr: 0.0915 - val_loss: 2.2967 - val_precision: 0.1341 - val_recall: 0.2881 - val_auc_pr: 0.1140\n",
      "Epoch 2/2\n",
      "761/762 [============================>.] - ETA: 0s - loss: 1.9817 - precision: 0.1092 - recall: 0.0525 - auc_pr: 0.0983          \n",
      "Epoch 2: val_loss improved from 2.29671 to 1.81954, saving model to output/fragrant-shape-7/best_model.keras\n",
      "762/762 [==============================] - 38s 49ms/step - loss: 1.9815 - precision: 0.1089 - recall: 0.0525 - auc_pr: 0.0982 - val_loss: 1.8195 - val_precision: 0.1333 - val_recall: 0.0247 - val_auc_pr: 0.0944\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "de7daf275a1a42c2b3542a562e320254",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=1.0, max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch/auc_pr</td><td>▁█</td></tr><tr><td>epoch/epoch</td><td>▁█</td></tr><tr><td>epoch/learning_rate</td><td>▁▁</td></tr><tr><td>epoch/loss</td><td>█▁</td></tr><tr><td>epoch/precision</td><td>▁█</td></tr><tr><td>epoch/recall</td><td>█▁</td></tr><tr><td>epoch/val_auc_pr</td><td>█▁</td></tr><tr><td>epoch/val_loss</td><td>█▁</td></tr><tr><td>epoch/val_precision</td><td>█▁</td></tr><tr><td>epoch/val_recall</td><td>█▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch/auc_pr</td><td>0.09821</td></tr><tr><td>epoch/epoch</td><td>1</td></tr><tr><td>epoch/learning_rate</td><td>0.001</td></tr><tr><td>epoch/loss</td><td>1.98147</td></tr><tr><td>epoch/precision</td><td>0.10893</td></tr><tr><td>epoch/recall</td><td>0.05252</td></tr><tr><td>epoch/val_auc_pr</td><td>0.0944</td></tr><tr><td>epoch/val_loss</td><td>1.81954</td></tr><tr><td>epoch/val_precision</td><td>0.13333</td></tr><tr><td>epoch/val_recall</td><td>0.02469</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">fragrant-shape-7</strong> at: <a href='https://wandb.ai/linn-official/aiml-handson/runs/8e6g66t9' target=\"_blank\">https://wandb.ai/linn-official/aiml-handson/runs/8e6g66t9</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20250109_144948-8e6g66t9/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "wandb.init(project=\"aiml-handson\", anonymous=\"allow\")\n",
    "callbacks = create_callbacks(wandb.run.name)\n",
    "history = model.fit(train_ds, validation_data=val_ds, epochs=args.epochs, shuffle=True, callbacks=callbacks)\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fc9f081-dd24-466c-874f-4978af71d5c9",
   "metadata": {},
   "source": [
    "# Results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed07eed9-78a7-4ddf-9478-9e58b4889603",
   "metadata": {},
   "source": [
    "[Training history comparison](https://wandb.ai/linn-official/Ring_Train/reports/val_loss-25-01-04-10-29-05---VmlldzoxMDgxMDUwMA?accessToken=aexjbxpy9q24ikedi0vf7k3edss8jszlldy15or6blpt5f3kaxxps6lt8ql3qgg5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8056ffed-b234-4203-8415-d9dd53f360c2",
   "metadata": {},
   "source": [
    "### Download trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "2629ee40-bc37-42dc-b7c8-8132c0e4416a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading...\n",
      "From (original): https://drive.google.com/uc?id=1m4oVnlxAC9MxXZsQU9oEfAeFYVzmEtsA\n",
      "From (redirected): https://drive.google.com/uc?id=1m4oVnlxAC9MxXZsQU9oEfAeFYVzmEtsA&confirm=t&uuid=983f762e-92af-4460-a445-1751d3c86ad3\n",
      "To: /home/linn/2024/dec/aiml-handson/galactic-rings/clean-shadow-84-slim.zip\n",
      "100%|██████████████████████████████████████| 2.46G/2.46G [01:59<00:00, 20.6MB/s]\n"
     ]
    }
   ],
   "source": [
    "!gdown --fuzzy \"https://drive.google.com/file/d/1m4oVnlxAC9MxXZsQU9oEfAeFYVzmEtsA/view?usp=sharing\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "2f1e9a8b-53bd-4d23-8028-e73d9c4ffb16",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!unzip clean-shadow-84-slim.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "502085c4-f8be-4060-9104-60b111b63ad6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "best_model.h5  history.json  train_filenames.csv  validation_filenames.csv\n"
     ]
    }
   ],
   "source": [
    "!ls clean-shadow-84-slim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "3810947c-7c55-4fb3-acfc-ad27f2f3c885",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "clean-shadow-84-slim.zip: OK\n"
     ]
    }
   ],
   "source": [
    "!echo \"6f3c1140c1b4a7f0cb02ceecaf5cb030  clean-shadow-84-slim.zip\" | md5sum -c"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8af0ad20-857e-4926-8a0a-15934c5ac1ff",
   "metadata": {},
   "source": [
    "### Download training data with visual selections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e2a55679-caa8-40f0-aebc-d53109a5e66d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading...\n",
      "From (original): https://drive.google.com/uc?id=1YPWqAfXbltU56Biz7j2_nyAWlyODEdUA\n",
      "From (redirected): https://drive.google.com/uc?id=1YPWqAfXbltU56Biz7j2_nyAWlyODEdUA&confirm=t&uuid=f38d2b45-3758-472d-952d-30378fe263e4\n",
      "To: /home/linn/2024/dec/aiml-handson/galactic-rings/E11dash.zip\n",
      "100%|██████████████████████████████████████| 67.7M/67.7M [00:01<00:00, 34.5MB/s]\n"
     ]
    }
   ],
   "source": [
    "!gdown --fuzzy \"https://drive.google.com/file/d/1YPWqAfXbltU56Biz7j2_nyAWlyODEdUA/view?usp=sharing\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "e826df79-2e4a-44af-b720-febc92b5c8e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E11dash.zip: OK\n"
     ]
    }
   ],
   "source": [
    "!echo \"4925fae1310595ce8fa7f47e98180953  E11dash.zip\" | md5sum -c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "8d572cd0-c28a-4e0b-840a-fa597cce3f45",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!unzip E11dash.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "1971536d-5834-4649-8ae6-68bb8042b8ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras.models import load_model\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, \\\n",
    "f1_score, roc_auc_score, roc_curve, balanced_accuracy_score, brier_score_loss, \\\n",
    "average_precision_score, fbeta_score, matthews_corrcoef, auc, precision_recall_curve, \\\n",
    "classification_report\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "f0a4bf9f-1754-4f03-a5c5-7467950e8b26",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "8e143434-1996-46b0-966e-16f3c99cdfb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_parser = argparse.ArgumentParser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "b9dbe147-87d8-4ee1-8d32-c76096e4f27e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreTrueAction(option_strings=['--write'], dest='write', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='Switch to enable writing results to disk', metavar=None)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eval_parser.add_argument('--trained-model', required=True, help=\"Path to trained model\")\n",
    "eval_parser.add_argument('--test-dir',  help=\"Directory containing validation or test images sorted into respective classes\")\n",
    "eval_parser.add_argument('--saved-ds',  default=False, help=\"Boolean flag that is true if test_dir points to a tf.data.Dataset object\")\n",
    "eval_parser.add_argument('--threshold', type=float,  default=0.5, help=\"Decimal threshold to use for creating CM, etc.\")\n",
    "eval_parser.add_argument('--write', action=\"store_true\", help=\"Switch to enable writing results to disk\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "f2438de5-c2db-43fd-88bf-596886263b6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_args = eval_parser.parse_args(['--test-dir','E11dash/test/','--trained-model','clean-shadow-84-slim/best_model.h5'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "c8ea0041-c2a9-4bb8-8b2b-7e050180fcff",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dir = eval_args.test_dir\n",
    "model_path = eval_args.trained_model\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "dae3d9f1-20be-4f8d-8f34-76288ff5abad",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_height, img_width = args.target_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "4534cdc2-4594-4d56-90fd-2bcb1c1c0840",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 2353 files belonging to 2 classes.\n"
     ]
    }
   ],
   "source": [
    "test_ds = tf.keras.utils.image_dataset_from_directory(\n",
    "  test_dir,\n",
    "  #color_mode='grayscale',\n",
    "  shuffle=False,\n",
    "  image_size=(img_height, img_width),\n",
    "  batch_size=None)\n",
    "\n",
    "filenames = test_ds.file_paths\n",
    "\n",
    "normalization_layer = layers.Rescaling(1./255)\n",
    "\n",
    "test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y))\n",
    "\n",
    "labels = test_ds.map(lambda _, label: label)\n",
    "\n",
    "AUTOTUNE = tf.data.AUTOTUNE\n",
    "test_ds = test_ds.batch(batch_size).cache().prefetch(buffer_size=AUTOTUNE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "cd457aba-8055-4329-9a16-0a7554073611",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = load_model(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "0fe8af67-73e3-4b1d-a992-9d6fd8c7301e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-01-09 14:57:19.489971: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [2353]\n",
      "\t [[{{node Placeholder/_4}}]]\n",
      "2025-01-09 14:57:19.490582: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [2353]\n",
      "\t [[{{node Placeholder/_4}}]]\n",
      "2025-01-09 14:57:20.181974: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [2353]\n",
      "\t [[{{node Placeholder/_4}}]]\n",
      "2025-01-09 14:57:20.182448: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [2353]\n",
      "\t [[{{node Placeholder/_4}}]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37/37 [==============================] - 3s 60ms/step\n"
     ]
    }
   ],
   "source": [
    "ground_truth = list(labels.as_numpy_iterator())\n",
    "predictions = model.predict(test_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "f4a5f0bf-8c3b-4a39-bd5d-317baf0d3387",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using a classification threshold 0.5\n"
     ]
    }
   ],
   "source": [
    "threshold = eval_args.threshold\n",
    "print(\"Using a classification threshold\", threshold)\n",
    "\n",
    "predicted_labels = [1 if pred >= threshold else 0 for pred in predictions]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "460b7682-0188-42f2-a338-c4617e2fcbc3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Confusion Matrix:\n",
      "[[2093   35]\n",
      " [  12  213]]\n"
     ]
    }
   ],
   "source": [
    "confusion_mtx = confusion_matrix(ground_truth, predicted_labels)\n",
    "print(\"Confusion Matrix:\")\n",
    "print(confusion_mtx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "dbe2d8ff-9d69-4492-9f45-f31be3d21cab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9800254993625159\n",
      "F1-score: 0.9006342494714588\n",
      "ROC AUC Score: 0.995064745196324\n",
      "PR AUC Score: 0.9219745599446181\n",
      "Brier score 0.015241051341857327\n",
      "Average precision score 0.9243849287069954\n",
      "Classification Report\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "    NonRings       0.99      0.98      0.99      2128\n",
      "       Rings       0.86      0.95      0.90       225\n",
      "\n",
      "    accuracy                           0.98      2353\n",
      "   macro avg       0.93      0.97      0.94      2353\n",
      "weighted avg       0.98      0.98      0.98      2353\n",
      "\n",
      "False Positive Rate (FPR): 0.01644736842105263\n",
      "TNR or Specificity: 0.9835526315789473\n",
      "G-Mean: 0.9649334128467467\n",
      "F-beta score beta=2 0.9277003484320558\n",
      "F1 Score After Thresholding: 0.9006342494714588\n",
      "Matthew Correlation Coefficient: 0.8908621868910627\n",
      "Balanced Accuracy: 0.9651096491228071\n"
     ]
    }
   ],
   "source": [
    "accuracy = accuracy_score(ground_truth, predicted_labels)\n",
    "f1 = f1_score(ground_truth, predicted_labels)\n",
    "try:\n",
    "    roc_auc = roc_auc_score(ground_truth, predictions)\n",
    "except:\n",
    "    print(\"Setting roc_auc to be -1 as it is not defined\")\n",
    "    roc_auc = -1\n",
    "precisions, recalls, thresholds = precision_recall_curve(ground_truth, predictions)\n",
    "pr_auc = auc(recalls, precisions)\n",
    "brier_score = brier_score_loss(ground_truth, predictions)\n",
    "avg_precision = average_precision_score(ground_truth, predictions)\n",
    "report = classification_report(ground_truth, predicted_labels, target_names=['NonRings', 'Rings'])\n",
    "\n",
    "tn, fp, fn, tp = confusion_mtx.ravel()\n",
    "fpr = fp / (fp + tn)\n",
    "specificity = tn / (fp + tn)\n",
    "precision = precision_score(ground_truth, predicted_labels)\n",
    "recall = recall_score(ground_truth, predicted_labels)\n",
    "bal_acc = balanced_accuracy_score(ground_truth, predicted_labels)\n",
    "matthews_corr_coef = matthews_corrcoef(ground_truth, predicted_labels)\n",
    "beta = 2\n",
    "fbeta = fbeta_score(ground_truth, predicted_labels, beta=beta)\n",
    "print(\"Accuracy:\", accuracy)\n",
    "print(\"F1-score:\", f1)\n",
    "print(\"ROC AUC Score:\", roc_auc)\n",
    "print(\"PR AUC Score:\", pr_auc)\n",
    "print(\"Brier score\", brier_score)\n",
    "print(\"Average precision score\", avg_precision)\n",
    "print(\"Classification Report\")\n",
    "print(report)\n",
    "print(\"False Positive Rate (FPR):\", fpr)\n",
    "print(\"TNR or Specificity:\", specificity)\n",
    "print(\"G-Mean:\", np.sqrt(recall * specificity))\n",
    "print(f\"F-beta score beta={beta}\", fbeta)\n",
    "print(\"F1 Score After Thresholding: {}\".format( f1_score(ground_truth, predicted_labels)))\n",
    "print(\"Matthew Correlation Coefficient:\", matthews_corr_coef)\n",
    "print(\"Balanced Accuracy:\", bal_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c04dfff9-d34a-4046-937a-a9e48d7f64a1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "solar311",
   "language": "python",
   "name": "solar311"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}