{ "cells": [ { "cell_type": "markdown", "id": "420b6034-fe36-483d-b676-a6eeea2e80c0", "metadata": {}, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "id": "58b3e6f3-7476-47bc-ade6-d5e020c78055", "metadata": {}, "source": [ "**Linn Abraham**\n", "* Final year PhD student - Thesis on the Application of Deep Learning to different problems in Astronomy\n", "* Here in IUCAA as part of the ISRO RESPOND project - **Solar Flares: Physics and Forecasting for better understanding of Space Weather**" ] }, { "cell_type": "markdown", "id": "b6a7da9c-fdfa-40fe-8daf-2313891cbaef", "metadata": {}, "source": [ "* [Automated Detection of Galactic Rings from Sloan Digital Sky Survey Images](https://doi.org/10.3847/1538-4357/ad856d)\n", " * Co-authors: Sheelu Abraham, Ajit Kembhavi, Ninan Sajeeth Philip, Sudhanshu Barwaye and others\n", "* Source detection for H1 galaxies from radio data\n", " * Co-authors: Kshitij H Thorat, Arun K. Aniyan and others\n", "* Solar active region classification and interpretability using Deep Learning\n", " * Co-authors: Durgesh Tripathi, Vishal Upendran and others " ] }, { "cell_type": "markdown", "id": "ef56eda3-1c3d-4d5e-9cb8-c07e5b55c960", "metadata": {}, "source": [ "Contact:\n", "* E-mail: linn.official@gmail.com\n", "* Linkedin - [www.linkedin.com/in/linn-abraham/](https://www.linkedin.com/in/linn-abraham/)\n", "* Github - [github.com/linnabraham](https://github.com/linnabraham)" ] }, { "cell_type": "markdown", "id": "cd31bc4e-19ac-4dd2-aa12-a56f55d88cd7", "metadata": {}, "source": [ "## Galaxy Ring Detection" ] }, { "cell_type": "markdown", "id": "7b65c0ce-bafb-429f-8cc6-e4f711cc53e9", "metadata": {}, "source": [ "### What is a ring in a galaxy?" ] }, { "cell_type": "markdown", "id": "a163a342-6059-4bdc-bf81-711cd212e168", "metadata": {}, "source": [ "Hubble Tuning Fork" ] }, { "cell_type": "markdown", "id": "803eab13-99a1-4c0c-a579-cc8b49076c85", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "88943f8c-26b3-473b-8403-5af31899c600", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "d2ec5086-94c3-4058-96a9-0d18bacdb5f2", "metadata": {}, "source": [ "* The understanding that rings are important morphological features for studying galaxy formation and evolution came later\n", "* de Vaucouleurs introduced the idea of a 3-dimensional classification volume, of which Hubble’s tuning fork is a sort of cross section parallel to one axis." ] }, { "cell_type": "markdown", "id": "3e2d214e-2420-44c0-b2f1-dc75522bd3c0", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "81c5cde1-90cd-445c-aa2b-f68ec1a33918", "metadata": {}, "source": [ "### SDSS Image cutouts" ] }, { "cell_type": "markdown", "id": "2f73e92f-aca5-468b-8f00-87293d3d2f41", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "d1e10c5e-3098-4874-8d7b-695528cc2cc4", "metadata": {}, "source": [ "## Resources" ] }, { "cell_type": "markdown", "id": "9eceee41-daf9-442a-a976-a8f09b1afd14", "metadata": {}, "source": [ "* ChatGPT, stackoverflow, blogs and video tutorials are helpful.\n", "* Return back to books to check your understanding and build up basic knowledge\n", "* Read papers for understanding the latest architectures and techniques." ] }, { "cell_type": "markdown", "id": "3f345aff-2099-4791-b9af-c8b353559bc9", "metadata": {}, "source": [ "#### Beginner level" ] }, { "cell_type": "markdown", "id": "0892f8b6-f96d-4e61-836a-c7f985117f35", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "c37077a4-9e15-490c-9b90-4d430355230e", "metadata": {}, "source": [ "#### More advanced" ] }, { "cell_type": "markdown", "id": "940f7d14-295a-4e89-9bf3-cc377379bc82", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "d76f4222-24c0-4c9d-a8a0-3fa00936d846", "metadata": {}, "source": [ "* [Compilation of resources Topic Wise](https://github.com/linnabraham/ml-tutorials/blob/master/Resources.md)\n", "* [Beginner's guide to machine learning - My Blog Entry](https://machinelearningmaniac.blogspot.com/2024/05/beginners-guide-to-machine-learning.html)\n", "* [3Blue1Brown Playlist on Neural Networks](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi)" ] }, { "cell_type": "markdown", "id": "bcfc8d71-9409-4d43-95cb-2f75db18f6a0", "metadata": {}, "source": [ "## Tips and Tricks\n", "Powered by own life experience" ] }, { "cell_type": "markdown", "id": "0631f3e1-bab1-47ce-b245-485a7a9a2330", "metadata": {}, "source": [ "### Choosing frameworks" ] }, { "cell_type": "markdown", "id": "1be5acc4-db45-468b-aba4-370cb5175c8d", "metadata": {}, "source": [ "* Tensorflow and PyTorch are the two major frameworks or platforms for deep learning in Python\n", "* Learning to write your own training loops is immensely useful\n", "* model.fit in tensorflow vs writing your own custom training loop in PyTorch\n", "* tf.data.Datasets modelled after PyTorch dataset\n", "* GradientTape api and custom training loops exist in tensorflow as well\n", "* (Asitang's Lecture in IUCAA 2023)" ] }, { "cell_type": "markdown", "id": "81f467df-b635-4ac2-8e3b-193a8f972052", "metadata": {}, "source": [ "### Word about IDEs" ] }, { "cell_type": "markdown", "id": "17027e00-1799-4dbd-a4dc-fbb0bd44a78b", "metadata": {}, "source": [ "* Notebooks have two good uses\n", " * Inline images (also Mardown support) which makes them good to present your code to yourself or others\n", " * Keep variables in memory until you quit the program\n", " * Reasons not to use it - [I Don't like Jupyter Notebooks](https://youtu.be/7jiPeIFXb6U)\n", "* Be comfortable with ViM + script workflow:\n", " * Save figures to disk and view them through SSH fs or something\n", " * Pylint\n", " * PDB for debugging\n", "* Once you know these things also try out VSCode and others" ] }, { "cell_type": "markdown", "id": "2cda9b90-6dd8-4b11-87c9-1f7e12d85ce9", "metadata": {}, "source": [ "### Virtual environment" ] }, { "cell_type": "markdown", "id": "dec78376-6034-4bd0-b579-590138c61616", "metadata": {}, "source": [ "* Always work in a virtual environment\n", "* Conda and Pip environments serve different use cases - https://machinelearningmaniac.blogspot.com/2024/05/a-guide-to-using-conda-for-managing.html\n", "* Version control your environment using environment.yml or requirements.txt for reproducibility" ] }, { "cell_type": "markdown", "id": "621549d2-853c-4edd-9e9e-b27fcf46667b", "metadata": {}, "source": [ "### General Tips" ] }, { "cell_type": "markdown", "id": "fca19442-e011-4333-a32a-0eaa1653a14d", "metadata": {}, "source": [ "* Never re-invent the wheel; always resort to well developed libraries\n", " * Scikit-learn, scipy, pandas, scikit-image, OpenCV \n", "* For things that don't exist elsewhere put up your custom utility libraries in github and install them into your own environment using Pip\n", "* Writing classes might help at some point when you data structure and functionalities get complex and interdependent" ] }, { "cell_type": "markdown", "id": "4595dc8f-9d54-46fe-9c11-46c7eec6eeb3", "metadata": {}, "source": [ "### Code version control" ] }, { "cell_type": "markdown", "id": "abb874a7-6051-47a1-89c6-1e299ca661ef", "metadata": {}, "source": [ "* Any ML person needs to use git - the early the better\n", "* On the command line first\n", "* Not just to avoid directories named project_v1, project_v2, project_v3\n", "* Very versatile tool\n", "* [Missing semester lecture](https://youtu.be/2sjqTHE0zok)" ] }, { "cell_type": "markdown", "id": "8b0c7779-9888-4d00-a17c-974c3f1b1aa7", "metadata": {}, "source": [ "### Clone git repo" ] }, { "cell_type": "code", "execution_count": 1, "id": "1163ae14-3856-486a-a0d5-bbc946c242a9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'galactic-rings'...\n", "remote: Enumerating objects: 587, done.\u001b[K\n", "remote: Counting objects: 100% (65/65), done.\u001b[K\n", "remote: Compressing objects: 100% (47/47), done.\u001b[K\n", "remote: Total 587 (delta 21), reused 50 (delta 10), pack-reused 522 (from 1)\u001b[K\n", "Receiving objects: 100% (587/587), 9.48 MiB | 5.31 MiB/s, done.\n", "Resolving deltas: 100% (328/328), done.\n" ] } ], "source": [ "!git clone -b handson https://github.com/linnabraham/galactic-rings.git" ] }, { "cell_type": "code", "execution_count": 2, "id": "dd6e9703-6ca0-4cfb-b336-3f231e69249b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/linn/2024/dec/aiml-handson/galactic-rings\n" ] } ], "source": [ "%cd galactic-rings" ] }, { "cell_type": "code", "execution_count": 3, "id": "f9f46dcf-e6a3-41fa-86ec-861b00c249d1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ".\n", "├── alexnet_utils\n", "├── create_pred_catalogue.py\n", "├── data\n", "├── environment.yml\n", "├── evaluate_alexnet.py\n", "├── figures\n", "├── helpers\n", "├── notebooks\n", "├── optional-deps.yml\n", "├── plot_hists.py\n", "├── predict_alexnet.py\n", "├── predict_single.py\n", "├── README.md\n", "├── requirements.txt\n", "├── separate_images.py\n", "├── torch\n", "├── train_alexnet_kfold.py\n", "├── train_alexnet.py\n", "└── train_lenet.py\n", "\n", "6 directories, 13 files\n" ] } ], "source": [ "!tree -L 1" ] }, { "cell_type": "markdown", "id": "1e2841a5-7248-42a4-99d9-11864290470d", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "markdown", "id": "fe834649-fc86-416f-9c29-672c07f7edf9", "metadata": {}, "source": [ "### Domain knowledge is important" ] }, { "cell_type": "markdown", "id": "c22ccf5a-8f4e-4bd9-9f06-54c566daae32", "metadata": {}, "source": [ "* Pan-STARRS vs SDSS\n", "* Pan-STARRS was much less suited for galaxy morphology analysis" ] }, { "cell_type": "markdown", "id": "f6764e17-cdfd-4e19-b3ba-70f95d332de4", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "83a3ad88-af54-4b53-b480-03d2c562d10e", "metadata": {}, "source": [ "### Quantity of your data" ] }, { "cell_type": "markdown", "id": "9a645321-8992-49d9-8d31-ef755bb25fde", "metadata": {}, "source": [ "* Used the two largest catalogs of galaxy morphology available\n", "* Buta (2017) for Rings\n", "* Nair. et. al. (2010) for Non-Rings" ] }, { "cell_type": "markdown", "id": "dcb50c6b-317f-468b-a776-0079654cf0f5", "metadata": {}, "source": [ "#### Sample of rings" ] }, { "cell_type": "markdown", "id": "ff5722de-969c-4fa0-bfe8-67bf8f8f798b", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "1e340be7-b250-424b-b07d-1acb142543bc", "metadata": {}, "source": [ "#### Rings: Buta (2017)\n", "* 3962 Rings\n", "* Citizen science + Expert analysis" ] }, { "cell_type": "markdown", "id": "c51db562-d7e7-4499-b579-e1fa9cbb4172", "metadata": {}, "source": [ "#### Non-Rings: Nair et. al. (2010) " ] }, { "cell_type": "markdown", "id": "ab04c00b-5c45-4c92-abf8-e2cd91f974ad", "metadata": {}, "source": [ "* 14,034 galaxies with both rings and non-rings\n", "* Ring classifications - nuclear, inner, outer, pseudo-outer (R1/R2), and collisional.\n", "* Brighter than 16 mag\n", "* Between 0.01 and 0.1 redshift" ] }, { "cell_type": "markdown", "id": "0519d146-ff3a-4fcf-b2ca-d37506963959", "metadata": {}, "source": [ "### Avoid any possible bias" ] }, { "cell_type": "markdown", "id": "70a9dba1-89b7-431f-ae71-a14d8b426321", "metadata": {}, "source": [ "* Apply the same selection criteria to Rings as in Non-Rings" ] }, { "cell_type": "markdown", "id": "2b9210a8-6417-4029-b8fa-8606d6c8a376", "metadata": {}, "source": [ "### Quality of your data matters" ] }, { "cell_type": "markdown", "id": "2a8c10d3-0a3d-48f6-bf20-5c3be7bec345", "metadata": {}, "source": [ "Looked through the large sample by eye to remove the \"bad\" data\n", "* Multiple galaxies\n", "* Outliers\n", "* Ambigous types\n", "\n", "Finally left with 1122 rings and 10,639 non-rings" ] }, { "cell_type": "markdown", "id": "6449c260-3993-464a-af0d-4cdb0bab1f09", "metadata": {}, "source": [ "### Data version control" ] }, { "cell_type": "markdown", "id": "88179b53-cd26-418c-a7e9-345e5f69eead", "metadata": {}, "source": [ "* Use [DVC](https://dvc.org/)\n", "* md5sum hashes for a directory structure + git" ] }, { "cell_type": "markdown", "id": "8b49e097-351a-49cc-ab24-8c3dd8712cd6", "metadata": {}, "source": [ "### Script for downloading galaxy image cutouts in bulk" ] }, { "cell_type": "markdown", "id": "c2a7c542-9a1e-43e1-9a39-692185bdabfe", "metadata": {}, "source": [ "* Multi-threading (not processing) for I/O or network intensive processes" ] }, { "cell_type": "code", "execution_count": 4, "id": "9edb23f7-34df-429d-ac32-1fb80e89c726", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading...\n", "From (original): https://drive.google.com/uc?id=1RlPl3WD4JDz5N-kx5g00YveLtZ43vPw-\n", "From (redirected): https://drive.google.com/uc?id=1RlPl3WD4JDz5N-kx5g00YveLtZ43vPw-&confirm=t&uuid=c4a3b2e1-99ee-41b4-ab31-ef057d60ecfc\n", "To: /home/linn/2024/dec/aiml-handson/galactic-rings/galaxies.tar.gz\n", "100%|██████████████████████████████████████| 77.6M/77.6M [00:06<00:00, 12.5MB/s]\n" ] } ], "source": [ "!gdown --fuzzy \"https://drive.google.com/file/d/1RlPl3WD4JDz5N-kx5g00YveLtZ43vPw-/view?usp=drive_link\"" ] }, { "cell_type": "code", "execution_count": 5, "id": "35268e0d-cbcd-4d6a-9f93-011df696474d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "galaxies.tar.gz: OK\n" ] } ], "source": [ "!echo \"0f456e955b0b5312aec8d2dd6186218c galaxies.tar.gz\" | md5sum -c" ] }, { "cell_type": "code", "execution_count": 6, "id": "9162f14e-c5d7-4ffd-a7db-79c9de50daee", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!tar xvzf galaxies.tar.gz -C data/" ] }, { "cell_type": "markdown", "id": "6879f158-1cfb-4461-8dd4-0c2da193a4ae", "metadata": {}, "source": [ "## Setup environment" ] }, { "cell_type": "code", "execution_count": 7, "id": "2158e591-a5cd-433b-9fde-2056b10c2f3e", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!pip install wandb\n", "!pip install gdown" ] }, { "cell_type": "code", "execution_count": null, "id": "6dcbeca7-37fa-4a2d-8c5b-cb32c96029eb", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!pip install tensorflow" ] }, { "cell_type": "markdown", "id": "c62cf196-0e53-4414-973b-abd8500321b6", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 8, "id": "ca7afdeb-08bd-454e-8c6f-21247494fba6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-01-09 14:48:33.496568: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "2025-01-09 14:48:35.498549: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] } ], "source": [ "import os\n", "import json\n", "import tensorflow as tf\n", "from tensorflow.keras.callbacks import ModelCheckpoint, Callback\n", "from alexnet_utils.params import parser, print_arguments\n", "from alexnet_utils.alexnet import AlexNet\n", "import wandb" ] }, { "cell_type": "code", "execution_count": 9, "id": "6f933e3b-e755-46b3-a314-33c90e179e8d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.12.1\n" ] } ], "source": [ "print(tf.__version__)" ] }, { "cell_type": "code", "execution_count": 10, "id": "6cec405a-e353-44d4-8bb0-7eb92cc836c7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.15.11\n" ] } ], "source": [ "print(wandb.__version__)" ] }, { "cell_type": "code", "execution_count": 11, "id": "6ef00404-47dd-497b-821a-919df6a01f88", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n" ] } ], "source": [ "print(tf.config.list_physical_devices('GPU'))" ] }, { "cell_type": "markdown", "id": "51e251d2-0ed1-427a-8abb-debc353c4131", "metadata": {}, "source": [ "## Data augmentations" ] }, { "cell_type": "markdown", "id": "b6113e98-a3dd-4dfa-bac8-2164829d9eff", "metadata": {}, "source": [ "* Rescale data to be between (0,1) make it easier for the network to do all the math.\n", "* Custom augmentations\n", " * Rotation\n", " * Flip\n", " * Brightness\n", " * Contrast" ] }, { "cell_type": "code", "execution_count": 12, "id": "b80d4e92-c79a-474d-abda-838edd1bb841", "metadata": {}, "outputs": [], "source": [ "def random_choice(x, size, seed, axis=0, unique=True):\n", " dim_x = tf.cast(tf.shape(x)[axis], tf.int64)\n", " indices = tf.range(0, dim_x, dtype=tf.int64)\n", " sample_index = tf.random.shuffle(indices,seed=seed)[:size]\n", " sample = tf.gather(x, sample_index, axis=axis)\n", "\n", " return sample, sample_index\n", "\n", "def random_int_rot_img(inputs,seed):\n", " angles = tf.constant([1, 2, 3, 4])\n", " # Make a new seed.\n", " #new_seed = tf.random.experimental.stateless_split((seed,seed), num=1)[0, :]\n", " angle = random_choice(angles,1,seed=seed)[0][0]\n", " inputs = tf.image.rot90(inputs, k=angle)\n", "\n", " return inputs\n", "\n", "def rescale(image, label):\n", " image = tf.cast(image, tf.float32)\n", " image = (image / 255.0)\n", "\n", " return image, label\n", "\n", "# define custom augmentations\n", "def augment_custom(images, labels, augmentation_types, seed):\n", " images, labels = rescale(images, labels)\n", " # Make a new seed.\n", " #new_seed = tf.random.experimental.stateless_split((seed,seed), num=1)[0, :]\n", " new_seed = seed\n", " if 'rotation' in augmentation_types:\n", " images = random_int_rot_img(images,seed=seed)\n", " if 'flip' in augmentation_types:\n", " images = tf.image.random_flip_left_right(images, seed=new_seed)\n", " images = tf.image.random_flip_up_down(images, seed=new_seed)\n", " if 'brightness' in augmentation_types:\n", " images = tf.image.random_brightness(images, max_delta=0.2, seed=new_seed)\n", " if 'contrast' in augmentation_types:\n", " images = tf.image.random_contrast(images, lower=0.2, upper=0.5, seed=new_seed)\n", "\n", " return (images, labels)" ] }, { "cell_type": "markdown", "id": "5cac1070-a99d-4f0e-8ea4-366adb5a7df4", "metadata": {}, "source": [ "# Train" ] }, { "cell_type": "markdown", "id": "c5648f3a-aacb-47c9-a76e-59dd6334d251", "metadata": {}, "source": [ "### Define argparse arguments" ] }, { "cell_type": "code", "execution_count": 13, "id": "583cdd74-010a-48df-b1a1-8cee9e12aaae", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "_StoreTrueAction(option_strings=['-retrain', '--retrain'], dest='retrain', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='Whether to continue previous training', metavar=None)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "parser.add_argument('-images', '--images', required=True, help=\"path containing images of two classes\")\n", "parser.add_argument('-epochs', '--epochs', required=True, type=int, default=50, help=\"num epochs\")\n", "parser.add_argument('-model-path', '--model-path', default=None, help=\"Filepath to save model during training and to load model from when testing\")\n", "parser.add_argument('-val-dir', '--val-dir', default=None, help=\"path containing validation data\")\n", "parser.add_argument('-retrain', '--retrain', action=\"store_true\", help=\"Whether to continue previous training\")" ] }, { "cell_type": "code", "execution_count": 14, "id": "65240c23-7815-4d2f-8f69-b3591737a63d", "metadata": {}, "outputs": [], "source": [ "args = parser.parse_args(['-images', 'data/galaxies', '-epochs', '2'])" ] }, { "cell_type": "code", "execution_count": 15, "id": "3915f923-bfa4-453a-bb48-d12acdec39c3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Arguments and Data Types:\n", " target_size: tuple_type - (240, 240)\n", " batch_size: int - 16\n", " train_frac: float - 0.8\n", " random_state: int - 42\n", " num_classes: int - 2\n", " channels: int - 3\n", " output_dir: None - output\n", " augmentation_types: str - ['flip', 'rotation']\n", " images: None - data/galaxies\n", " epochs: int - 2\n", " model_path: None - None\n", " val_dir: None - None\n" ] } ], "source": [ "print_arguments(parser, args)" ] }, { "cell_type": "markdown", "id": "06e8b9a3-9f69-4e9e-ad74-a6a086425526", "metadata": {}, "source": [ "## Network Architecture" ] }, { "cell_type": "markdown", "id": "406180ab-4444-47db-b2e5-b5d537d375a6", "metadata": {}, "source": [ "* Choice of the architecture should match your data size\n", "* If your architecture of choice is not pre-defined in your favourite framework you (or someone else) has to implement it by hand" ] }, { "cell_type": "markdown", "id": "b37347fb-693f-4ace-9d6f-e9d90fa2d41c", "metadata": {}, "source": [ "### Define AlexNet architecture" ] }, { "cell_type": "code", "execution_count": 16, "id": "b8cf3584-5842-4e7b-a1a7-084b494806cf", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-01-09 14:49:20.192419: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38367 MB memory: -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:37:00.0, compute capability: 8.0\n" ] } ], "source": [ "model = AlexNet.build(width=args.target_size[0], height=args.target_size[1], \\\n", " depth=args.channels, classes=1, reg=0.0002)" ] }, { "cell_type": "code", "execution_count": 17, "id": "f0ced6e9-1d13-4934-bf8a-1df94450ab2b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " conv2d (Conv2D) (None, 120, 120, 96) 7296 \n", " \n", " activation (Activation) (None, 120, 120, 96) 0 \n", " \n", " batch_normalization (BatchN (None, 120, 120, 96) 384 \n", " ormalization) \n", " \n", " max_pooling2d (MaxPooling2D (None, 59, 59, 96) 0 \n", " ) \n", " \n", " dropout (Dropout) (None, 59, 59, 96) 0 \n", " \n", " conv2d_1 (Conv2D) (None, 59, 59, 256) 614656 \n", " \n", " activation_1 (Activation) (None, 59, 59, 256) 0 \n", " \n", " batch_normalization_1 (Batc (None, 59, 59, 256) 1024 \n", " hNormalization) \n", " \n", " max_pooling2d_1 (MaxPooling (None, 29, 29, 256) 0 \n", " 2D) \n", " \n", " dropout_1 (Dropout) (None, 29, 29, 256) 0 \n", " \n", " conv2d_2 (Conv2D) (None, 29, 29, 384) 885120 \n", " \n", " activation_2 (Activation) (None, 29, 29, 384) 0 \n", " \n", " batch_normalization_2 (Batc (None, 29, 29, 384) 1536 \n", " hNormalization) \n", " \n", " conv2d_3 (Conv2D) (None, 29, 29, 384) 1327488 \n", " \n", " activation_3 (Activation) (None, 29, 29, 384) 0 \n", " \n", " batch_normalization_3 (Batc (None, 29, 29, 384) 1536 \n", " hNormalization) \n", " \n", " conv2d_4 (Conv2D) (None, 29, 29, 256) 884992 \n", " \n", " activation_4 (Activation) (None, 29, 29, 256) 0 \n", " \n", " batch_normalization_4 (Batc (None, 29, 29, 256) 1024 \n", " hNormalization) \n", " \n", " max_pooling2d_2 (MaxPooling (None, 14, 14, 256) 0 \n", " 2D) \n", " \n", " dropout_2 (Dropout) (None, 14, 14, 256) 0 \n", " \n", " flatten (Flatten) (None, 50176) 0 \n", " \n", " dense (Dense) (None, 4096) 205524992 \n", " \n", " activation_5 (Activation) (None, 4096) 0 \n", " \n", " batch_normalization_5 (Batc (None, 4096) 16384 \n", " hNormalization) \n", " \n", " dropout_3 (Dropout) (None, 4096) 0 \n", " \n", " dense_1 (Dense) (None, 4096) 16781312 \n", " \n", " activation_6 (Activation) (None, 4096) 0 \n", " \n", " batch_normalization_6 (Batc (None, 4096) 16384 \n", " hNormalization) \n", " \n", " dropout_4 (Dropout) (None, 4096) 0 \n", " \n", " dense_2 (Dense) (None, 1) 4097 \n", " \n", " activation_7 (Activation) (None, 1) 0 \n", " \n", "=================================================================\n", "Total params: 226,068,225\n", "Trainable params: 226,049,089\n", "Non-trainable params: 19,136\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "print(model.summary())" ] }, { "cell_type": "markdown", "id": "84efc9da-7b30-4e9a-a144-4f3f9310943d", "metadata": {}, "source": [ "## Define validation loss, evaluation metrics, optimizer and learning rate" ] }, { "cell_type": "markdown", "id": "8ea67bdd-fd4f-4109-984c-6669459a004d", "metadata": {}, "source": [ "Image Credit: Chollet\n", "" ] }, { "cell_type": "code", "execution_count": 18, "id": "487f8ed4-b038-4052-a28e-89e7f51a0854", "metadata": {}, "outputs": [], "source": [ "classification_threshold = 0.5\n", "\n", "METRICS = [\n", " tf.keras.metrics.Precision(thresholds=classification_threshold,\n", " name='precision'),\n", " tf.keras.metrics.Recall(thresholds=classification_threshold,\n", " name=\"recall\"),\n", " tf.keras.metrics.AUC(num_thresholds=100, curve='PR', name='auc_pr'),\n", "]" ] }, { "cell_type": "code", "execution_count": 19, "id": "4d7783a9-47e9-42f0-ad2d-9df1a91792af", "metadata": {}, "outputs": [], "source": [ "model.compile(loss=\"binary_crossentropy\", optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3)\\\n", " , metrics=METRICS)" ] }, { "cell_type": "markdown", "id": "9b0f1bea-9a2b-4a58-9206-9aaa54a920bc", "metadata": {}, "source": [ "## Define Callbacks" ] }, { "cell_type": "markdown", "id": "62d65eed-2448-4875-9d39-793288feae06", "metadata": {}, "source": [ "* Model checkpoint\n", "* Save history\n", "* Wandb logging" ] }, { "cell_type": "code", "execution_count": 20, "id": "5a211631-dfdb-4772-81cf-313c015ecd35", "metadata": {}, "outputs": [], "source": [ "class SaveHistoryCallback(Callback):\n", " def __init__(self, file_path):\n", " super().__init__()\n", " self.file_path = file_path\n", " self.history = {'loss': [], 'val_loss': [], 'auc_pr':[], 'val_auc_pr':[], 'val_precision':[], 'val_recall':[]}\n", "\n", " def on_epoch_end(self, epoch, logs=None):\n", " self.history['loss'].append(logs.get('loss'))\n", " self.history['val_loss'].append(logs.get('val_loss'))\n", " self.history['auc_pr'].append(logs.get('auc_pr'))\n", " self.history['val_auc_pr'].append(logs.get('val_auc_pr'))\n", " self.history['val_precision'].append(logs.get('val_precision'))\n", " self.history['val_recall'].append(logs.get('val_recall'))\n", "\n", " with open(self.file_path, 'w') as f:\n", " json.dump(self.history, f)" ] }, { "cell_type": "code", "execution_count": 21, "id": "627eb5f6-5301-4830-8b26-7a20426a7a79", "metadata": {}, "outputs": [], "source": [ "def create_callbacks(run_name):\n", " outdir = os.path.join(\"output\", run_name)\n", " if not os.path.exists(outdir):\n", " os.makedirs(outdir)\n", " model_path = os.path.join(outdir,\"best_model.keras\")\n", " mc = ModelCheckpoint(model_path, monitor='val_loss', \\\n", " mode='min', verbose=1, save_best_only=True)\n", " history_path = os.path.join(outdir,'history.json')\n", " hc = SaveHistoryCallback(history_path)\n", " callbacks=[mc,hc, wandb.keras.WandbMetricsLogger()]\n", " return callbacks" ] }, { "cell_type": "markdown", "id": "23f58c76-5e26-4b86-8c5d-eb4bdbccaa9d", "metadata": {}, "source": [ "## Create tf.data.Dataset" ] }, { "cell_type": "code", "execution_count": 22, "id": "d9bfc099-40e2-4dc6-b27f-6fc191c63807", "metadata": {}, "outputs": [], "source": [ "def get_train_data(data_dir, val_dir, train_frac, target_size, batch_size, augmentation_types, outdir, random_state):\n", " if val_dir is None:\n", " train_ds, val_ds = tf.keras.utils.image_dataset_from_directory(\n", " data_dir,\n", " validation_split=1-train_frac,\n", " subset=\"both\",\n", " color_mode='rgb',\n", " seed=random_state,\n", " image_size=target_size,\n", " batch_size=None)\n", " else:\n", " train_ds = tf.keras.utils.image_dataset_from_directory(\n", " data_dir,\n", " color_mode='rgb',\n", " seed=random_state,\n", " image_size=target_size,\n", " batch_size=None)\n", "\n", " val_ds = tf.keras.utils.image_dataset_from_directory(\n", " val_dir,\n", " color_mode='rgb',\n", " seed=random_state,\n", " image_size=target_size,\n", " batch_size=None)\n", "\n", " class_names = train_ds.class_names\n", " print(\"Training dataset class names are :\",class_names)\n", "\n", " AUTOTUNE = tf.data.AUTOTUNE\n", "\n", " train_ds = (\n", " train_ds\n", " .shuffle(1000)\n", " .map(lambda x, y: augment_custom(x, y, augmentation_types, seed=random_state), num_parallel_calls=AUTOTUNE)\n", " #.cache()\n", " .batch(batch_size)\n", " .prefetch(buffer_size=AUTOTUNE)\n", " )\n", "\n", " val_ds = (\n", " val_ds\n", " .map(rescale, num_parallel_calls=AUTOTUNE)\n", " #.cache()\n", " .batch(batch_size)\n", " .prefetch(buffer_size=AUTOTUNE)\n", " )\n", "\n", " return train_ds, val_ds" ] }, { "cell_type": "code", "execution_count": 23, "id": "0e2804a5-ca55-4cd1-a6ce-823b805d5f91", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 15229 files belonging to 2 classes.\n", "Using 12184 files for training.\n", "Using 3045 files for validation.\n", "Training dataset class names are : ['NonRings', 'Rings']\n" ] } ], "source": [ "train_ds, val_ds = get_train_data(args.images, args.val_dir, args.train_frac, args.target_size, args.batch_size,\\\n", " args.augmentation_types, args.output_dir, args.random_state)" ] }, { "cell_type": "code", "execution_count": 24, "id": "3aeddb1a-35e1-4a88-85ea-1fde093d83b0", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlinn-official\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "wandb version 0.19.2 is available! To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.15.11" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in <code>/home/linn/2024/dec/aiml-handson/galactic-rings/wandb/run-20250109_144948-8e6g66t9</code>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run <strong><a href='https://wandb.ai/linn-official/aiml-handson/runs/8e6g66t9' target=\"_blank\">fragrant-shape-7</a></strong> to <a href='https://wandb.ai/linn-official/aiml-handson' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at <a href='https://wandb.ai/linn-official/aiml-handson' target=\"_blank\">https://wandb.ai/linn-official/aiml-handson</a>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at <a href='https://wandb.ai/linn-official/aiml-handson/runs/8e6g66t9' target=\"_blank\">https://wandb.ai/linn-official/aiml-handson/runs/8e6g66t9</a>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "2025-01-09 14:50:06.142372: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [12184]\n", "\t [[{{node Placeholder/_0}}]]\n", "2025-01-09 14:50:06.143000: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [12184]\n", "\t [[{{node Placeholder/_4}}]]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-01-09 14:50:10.550579: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential/dropout/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer\n", "2025-01-09 14:50:12.780207: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:424] Loaded cuDNN version 8902\n", "2025-01-09 14:50:16.787475: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:637] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.\n", "2025-01-09 14:50:17.910929: I tensorflow/compiler/xla/service/service.cc:169] XLA service 0x7fe8d837d990 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n", "2025-01-09 14:50:17.910983: I tensorflow/compiler/xla/service/service.cc:177] StreamExecutor device (0): NVIDIA A100-PCIE-40GB, Compute Capability 8.0\n", "2025-01-09 14:50:18.240301: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", "2025-01-09 14:50:19.617198: I ./tensorflow/compiler/jit/device_compiler.h:180] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "762/762 [==============================] - ETA: 0s - loss: 3.1502 - precision: 0.0976 - recall: 0.0935 - auc_pr: 0.0915 " ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-01-09 14:50:58.921923: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype string and shape [3045]\n", "\t [[{{node Placeholder/_0}}]]\n", "2025-01-09 14:50:58.923012: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [3045]\n", "\t [[{{node Placeholder/_4}}]]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Epoch 1: val_loss improved from inf to 2.29671, saving model to output/fragrant-shape-7/best_model.keras\n", "762/762 [==============================] - 63s 49ms/step - loss: 3.1502 - precision: 0.0976 - recall: 0.0935 - auc_pr: 0.0915 - val_loss: 2.2967 - val_precision: 0.1341 - val_recall: 0.2881 - val_auc_pr: 0.1140\n", "Epoch 2/2\n", "761/762 [============================>.] - ETA: 0s - loss: 1.9817 - precision: 0.1092 - recall: 0.0525 - auc_pr: 0.0983 \n", "Epoch 2: val_loss improved from 2.29671 to 1.81954, saving model to output/fragrant-shape-7/best_model.keras\n", "762/762 [==============================] - 38s 49ms/step - loss: 1.9815 - precision: 0.1089 - recall: 0.0525 - auc_pr: 0.0982 - val_loss: 1.8195 - val_precision: 0.1333 - val_recall: 0.0247 - val_auc_pr: 0.0944\n" ] }, { "data": { "text/html": [ "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "de7daf275a1a42c2b3542a562e320254", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=1.0, max…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "<style>\n", " table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n", " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n", " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n", " </style>\n", "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch/auc_pr</td><td>▁█</td></tr><tr><td>epoch/epoch</td><td>▁█</td></tr><tr><td>epoch/learning_rate</td><td>▁▁</td></tr><tr><td>epoch/loss</td><td>█▁</td></tr><tr><td>epoch/precision</td><td>▁█</td></tr><tr><td>epoch/recall</td><td>█▁</td></tr><tr><td>epoch/val_auc_pr</td><td>█▁</td></tr><tr><td>epoch/val_loss</td><td>█▁</td></tr><tr><td>epoch/val_precision</td><td>█▁</td></tr><tr><td>epoch/val_recall</td><td>█▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch/auc_pr</td><td>0.09821</td></tr><tr><td>epoch/epoch</td><td>1</td></tr><tr><td>epoch/learning_rate</td><td>0.001</td></tr><tr><td>epoch/loss</td><td>1.98147</td></tr><tr><td>epoch/precision</td><td>0.10893</td></tr><tr><td>epoch/recall</td><td>0.05252</td></tr><tr><td>epoch/val_auc_pr</td><td>0.0944</td></tr><tr><td>epoch/val_loss</td><td>1.81954</td></tr><tr><td>epoch/val_precision</td><td>0.13333</td></tr><tr><td>epoch/val_recall</td><td>0.02469</td></tr></table><br/></div></div>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run <strong style=\"color:#cdcd00\">fragrant-shape-7</strong> at: <a href='https://wandb.ai/linn-official/aiml-handson/runs/8e6g66t9' target=\"_blank\">https://wandb.ai/linn-official/aiml-handson/runs/8e6g66t9</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: <code>./wandb/run-20250109_144948-8e6g66t9/logs</code>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "wandb.init(project=\"aiml-handson\", anonymous=\"allow\")\n", "callbacks = create_callbacks(wandb.run.name)\n", "history = model.fit(train_ds, validation_data=val_ds, epochs=args.epochs, shuffle=True, callbacks=callbacks)\n", "wandb.finish()" ] }, { "cell_type": "markdown", "id": "6fc9f081-dd24-466c-874f-4978af71d5c9", "metadata": {}, "source": [ "# Results" ] }, { "cell_type": "markdown", "id": "ed07eed9-78a7-4ddf-9478-9e58b4889603", "metadata": {}, "source": [ "[Training history comparison](https://wandb.ai/linn-official/Ring_Train/reports/val_loss-25-01-04-10-29-05---VmlldzoxMDgxMDUwMA?accessToken=aexjbxpy9q24ikedi0vf7k3edss8jszlldy15or6blpt5f3kaxxps6lt8ql3qgg5)" ] }, { "cell_type": "markdown", "id": "8056ffed-b234-4203-8415-d9dd53f360c2", "metadata": {}, "source": [ "### Download trained model" ] }, { "cell_type": "code", "execution_count": 25, "id": "2629ee40-bc37-42dc-b7c8-8132c0e4416a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading...\n", "From (original): https://drive.google.com/uc?id=1m4oVnlxAC9MxXZsQU9oEfAeFYVzmEtsA\n", "From (redirected): https://drive.google.com/uc?id=1m4oVnlxAC9MxXZsQU9oEfAeFYVzmEtsA&confirm=t&uuid=983f762e-92af-4460-a445-1751d3c86ad3\n", "To: /home/linn/2024/dec/aiml-handson/galactic-rings/clean-shadow-84-slim.zip\n", "100%|██████████████████████████████████████| 2.46G/2.46G [01:59<00:00, 20.6MB/s]\n" ] } ], "source": [ "!gdown --fuzzy \"https://drive.google.com/file/d/1m4oVnlxAC9MxXZsQU9oEfAeFYVzmEtsA/view?usp=sharing\"" ] }, { "cell_type": "code", "execution_count": 26, "id": "2f1e9a8b-53bd-4d23-8028-e73d9c4ffb16", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!unzip clean-shadow-84-slim.zip" ] }, { "cell_type": "code", "execution_count": 27, "id": "502085c4-f8be-4060-9104-60b111b63ad6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "best_model.h5 history.json train_filenames.csv validation_filenames.csv\n" ] } ], "source": [ "!ls clean-shadow-84-slim" ] }, { "cell_type": "code", "execution_count": 28, "id": "3810947c-7c55-4fb3-acfc-ad27f2f3c885", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "clean-shadow-84-slim.zip: OK\n" ] } ], "source": [ "!echo \"6f3c1140c1b4a7f0cb02ceecaf5cb030 clean-shadow-84-slim.zip\" | md5sum -c" ] }, { "cell_type": "markdown", "id": "8af0ad20-857e-4926-8a0a-15934c5ac1ff", "metadata": {}, "source": [ "### Download training data with visual selections" ] }, { "cell_type": "code", "execution_count": 29, "id": "e2a55679-caa8-40f0-aebc-d53109a5e66d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading...\n", "From (original): https://drive.google.com/uc?id=1YPWqAfXbltU56Biz7j2_nyAWlyODEdUA\n", "From (redirected): https://drive.google.com/uc?id=1YPWqAfXbltU56Biz7j2_nyAWlyODEdUA&confirm=t&uuid=f38d2b45-3758-472d-952d-30378fe263e4\n", "To: /home/linn/2024/dec/aiml-handson/galactic-rings/E11dash.zip\n", "100%|██████████████████████████████████████| 67.7M/67.7M [00:01<00:00, 34.5MB/s]\n" ] } ], "source": [ "!gdown --fuzzy \"https://drive.google.com/file/d/1YPWqAfXbltU56Biz7j2_nyAWlyODEdUA/view?usp=sharing\"" ] }, { "cell_type": "code", "execution_count": 30, "id": "e826df79-2e4a-44af-b720-febc92b5c8e3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "E11dash.zip: OK\n" ] } ], "source": [ "!echo \"4925fae1310595ce8fa7f47e98180953 E11dash.zip\" | md5sum -c" ] }, { "cell_type": "code", "execution_count": 31, "id": "8d572cd0-c28a-4e0b-840a-fa597cce3f45", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!unzip E11dash.zip" ] }, { "cell_type": "code", "execution_count": 32, "id": "1971536d-5834-4649-8ae6-68bb8042b8ef", "metadata": {}, "outputs": [], "source": [ "import argparse\n", "from tensorflow.keras import layers\n", "from tensorflow.keras.models import load_model\n", "from sklearn.metrics import confusion_matrix\n", "from sklearn.metrics import accuracy_score, precision_score, recall_score, \\\n", "f1_score, roc_auc_score, roc_curve, balanced_accuracy_score, brier_score_loss, \\\n", "average_precision_score, fbeta_score, matthews_corrcoef, auc, precision_recall_curve, \\\n", "classification_report\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 33, "id": "f0a4bf9f-1754-4f03-a5c5-7467950e8b26", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!pip install scikit-learn" ] }, { "cell_type": "code", "execution_count": 34, "id": "8e143434-1996-46b0-966e-16f3c99cdfb8", "metadata": {}, "outputs": [], "source": [ "eval_parser = argparse.ArgumentParser()" ] }, { "cell_type": "code", "execution_count": 35, "id": "b9dbe147-87d8-4ee1-8d32-c76096e4f27e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "_StoreTrueAction(option_strings=['--write'], dest='write', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='Switch to enable writing results to disk', metavar=None)" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eval_parser.add_argument('--trained-model', required=True, help=\"Path to trained model\")\n", "eval_parser.add_argument('--test-dir', help=\"Directory containing validation or test images sorted into respective classes\")\n", "eval_parser.add_argument('--saved-ds', default=False, help=\"Boolean flag that is true if test_dir points to a tf.data.Dataset object\")\n", "eval_parser.add_argument('--threshold', type=float, default=0.5, help=\"Decimal threshold to use for creating CM, etc.\")\n", "eval_parser.add_argument('--write', action=\"store_true\", help=\"Switch to enable writing results to disk\")" ] }, { "cell_type": "code", "execution_count": 36, "id": "f2438de5-c2db-43fd-88bf-596886263b6d", "metadata": {}, "outputs": [], "source": [ "eval_args = eval_parser.parse_args(['--test-dir','E11dash/test/','--trained-model','clean-shadow-84-slim/best_model.h5'])" ] }, { "cell_type": "code", "execution_count": 37, "id": "c8ea0041-c2a9-4bb8-8b2b-7e050180fcff", "metadata": {}, "outputs": [], "source": [ "test_dir = eval_args.test_dir\n", "model_path = eval_args.trained_model\n", "batch_size = 64" ] }, { "cell_type": "code", "execution_count": 38, "id": "dae3d9f1-20be-4f8d-8f34-76288ff5abad", "metadata": {}, "outputs": [], "source": [ "img_height, img_width = args.target_size" ] }, { "cell_type": "code", "execution_count": 39, "id": "4534cdc2-4594-4d56-90fd-2bcb1c1c0840", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 2353 files belonging to 2 classes.\n" ] } ], "source": [ "test_ds = tf.keras.utils.image_dataset_from_directory(\n", " test_dir,\n", " #color_mode='grayscale',\n", " shuffle=False,\n", " image_size=(img_height, img_width),\n", " batch_size=None)\n", "\n", "filenames = test_ds.file_paths\n", "\n", "normalization_layer = layers.Rescaling(1./255)\n", "\n", "test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y))\n", "\n", "labels = test_ds.map(lambda _, label: label)\n", "\n", "AUTOTUNE = tf.data.AUTOTUNE\n", "test_ds = test_ds.batch(batch_size).cache().prefetch(buffer_size=AUTOTUNE)" ] }, { "cell_type": "code", "execution_count": 40, "id": "cd457aba-8055-4329-9a16-0a7554073611", "metadata": {}, "outputs": [], "source": [ "model = load_model(model_path)" ] }, { "cell_type": "code", "execution_count": 41, "id": "0fe8af67-73e3-4b1d-a992-9d6fd8c7301e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-01-09 14:57:19.489971: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [2353]\n", "\t [[{{node Placeholder/_4}}]]\n", "2025-01-09 14:57:19.490582: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [2353]\n", "\t [[{{node Placeholder/_4}}]]\n", "2025-01-09 14:57:20.181974: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [2353]\n", "\t [[{{node Placeholder/_4}}]]\n", "2025-01-09 14:57:20.182448: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_4' with dtype int32 and shape [2353]\n", "\t [[{{node Placeholder/_4}}]]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "37/37 [==============================] - 3s 60ms/step\n" ] } ], "source": [ "ground_truth = list(labels.as_numpy_iterator())\n", "predictions = model.predict(test_ds)" ] }, { "cell_type": "code", "execution_count": 42, "id": "f4a5f0bf-8c3b-4a39-bd5d-317baf0d3387", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using a classification threshold 0.5\n" ] } ], "source": [ "threshold = eval_args.threshold\n", "print(\"Using a classification threshold\", threshold)\n", "\n", "predicted_labels = [1 if pred >= threshold else 0 for pred in predictions]" ] }, { "cell_type": "code", "execution_count": 43, "id": "460b7682-0188-42f2-a338-c4617e2fcbc3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Confusion Matrix:\n", "[[2093 35]\n", " [ 12 213]]\n" ] } ], "source": [ "confusion_mtx = confusion_matrix(ground_truth, predicted_labels)\n", "print(\"Confusion Matrix:\")\n", "print(confusion_mtx)" ] }, { "cell_type": "code", "execution_count": 44, "id": "dbe2d8ff-9d69-4492-9f45-f31be3d21cab", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.9800254993625159\n", "F1-score: 0.9006342494714588\n", "ROC AUC Score: 0.995064745196324\n", "PR AUC Score: 0.9219745599446181\n", "Brier score 0.015241051341857327\n", "Average precision score 0.9243849287069954\n", "Classification Report\n", " precision recall f1-score support\n", "\n", " NonRings 0.99 0.98 0.99 2128\n", " Rings 0.86 0.95 0.90 225\n", "\n", " accuracy 0.98 2353\n", " macro avg 0.93 0.97 0.94 2353\n", "weighted avg 0.98 0.98 0.98 2353\n", "\n", "False Positive Rate (FPR): 0.01644736842105263\n", "TNR or Specificity: 0.9835526315789473\n", "G-Mean: 0.9649334128467467\n", "F-beta score beta=2 0.9277003484320558\n", "F1 Score After Thresholding: 0.9006342494714588\n", "Matthew Correlation Coefficient: 0.8908621868910627\n", "Balanced Accuracy: 0.9651096491228071\n" ] } ], "source": [ "accuracy = accuracy_score(ground_truth, predicted_labels)\n", "f1 = f1_score(ground_truth, predicted_labels)\n", "try:\n", " roc_auc = roc_auc_score(ground_truth, predictions)\n", "except:\n", " print(\"Setting roc_auc to be -1 as it is not defined\")\n", " roc_auc = -1\n", "precisions, recalls, thresholds = precision_recall_curve(ground_truth, predictions)\n", "pr_auc = auc(recalls, precisions)\n", "brier_score = brier_score_loss(ground_truth, predictions)\n", "avg_precision = average_precision_score(ground_truth, predictions)\n", "report = classification_report(ground_truth, predicted_labels, target_names=['NonRings', 'Rings'])\n", "\n", "tn, fp, fn, tp = confusion_mtx.ravel()\n", "fpr = fp / (fp + tn)\n", "specificity = tn / (fp + tn)\n", "precision = precision_score(ground_truth, predicted_labels)\n", "recall = recall_score(ground_truth, predicted_labels)\n", "bal_acc = balanced_accuracy_score(ground_truth, predicted_labels)\n", "matthews_corr_coef = matthews_corrcoef(ground_truth, predicted_labels)\n", "beta = 2\n", "fbeta = fbeta_score(ground_truth, predicted_labels, beta=beta)\n", "print(\"Accuracy:\", accuracy)\n", "print(\"F1-score:\", f1)\n", "print(\"ROC AUC Score:\", roc_auc)\n", "print(\"PR AUC Score:\", pr_auc)\n", "print(\"Brier score\", brier_score)\n", "print(\"Average precision score\", avg_precision)\n", "print(\"Classification Report\")\n", "print(report)\n", "print(\"False Positive Rate (FPR):\", fpr)\n", "print(\"TNR or Specificity:\", specificity)\n", "print(\"G-Mean:\", np.sqrt(recall * specificity))\n", "print(f\"F-beta score beta={beta}\", fbeta)\n", "print(\"F1 Score After Thresholding: {}\".format( f1_score(ground_truth, predicted_labels)))\n", "print(\"Matthew Correlation Coefficient:\", matthews_corr_coef)\n", "print(\"Balanced Accuracy:\", bal_acc)" ] }, { "cell_type": "code", "execution_count": null, "id": "c04dfff9-d34a-4046-937a-a9e48d7f64a1", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "solar311", "language": "python", "name": "solar311" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }