diff --git a/README.md b/README.md index d121222..fda6149 100644 --- a/README.md +++ b/README.md @@ -13,11 +13,10 @@ synthetic radio data generated using MATLAB, showcases our commitment to interop approach to innovation. Classification results are comparable to those reported by MathWorks' AI-based network. For more information, -please refer to the following MathWorks article: +please refer to the following article by MathWorks: [Spectrum Sensing with Deep Learning to Identify 5G and LTE Signals](https://www.mathworks.com/help/comm/ug/spectrum-sensing-with-deep-learning-to-identify-5g-and-lte-signals.html). -If you found this example interesting or helpful, don't forget to give it a star! ⭐ Also, be sure to check out our -open-source project: [RIA Core](https://github.com/qoherent/ria). +If you found this example interesting or helpful, don't forget to give it a star! ⭐ ## 🚀 Getting Started @@ -25,7 +24,7 @@ open-source project: [RIA Core](https://github.com/qoherent/ria). This example is provided as a Jupyter Notebook. You have the option to either run this example locally or in Google Colab. -To run this example locally, you'll need to download this project and dataset, and set up a Conda +To run this example locally, you'll need to download the project and dataset and set up a Conda virtual environment. If this seems daunting, we recommend running this example on Google Colab. ### Running this example locally @@ -33,8 +32,8 @@ virtual environment. If this seems daunting, we recommend running this example o Please note that running this example locally will require approximately 10 GB of free space. Please ensure you have sufficient space available prior to proceeding. -1. Ensure that [Python](https://www.python.org/downloads/), [Git](https://git-scm.com/downloads), and [Conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) are installed on the computer where you plan to run -this example. Additionally, if you'd like to accelerate model training with a GPU, you'll require [CUDA](https://docs.nvidia.com/cuda/cuda-quick-start-guide/index.html). +1. Ensure that [Git](https://git-scm.com/downloads) and [Conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) are installed on the computer where you plan to run this example. +Additionally, if you'd like to accelerate model training with a GPU, you'll require [CUDA](https://docs.nvidia.com/cuda/cuda-quick-start-guide/index.html). 2. Clone this repository to your local computer: @@ -109,6 +108,8 @@ page [here](https://github.com/qoherent/spectrogram-segmentation/issues). Has this example inspired a project or research initiative related to intelligent radio? Please [get in touch](mailto:info@qoherent.ai); we'd love to collaborate with you! 📡🚀 +Finally, be sure to check out our open-source project: [RIA Core](https://github.com/qoherent/ria). + ## 🖊️ Authorship @@ -119,10 +120,10 @@ for sharing. ## 🙏 Attribution -The dataset used in this example was prepared by MathWorks and is publicly available under the MIT license -[here](https://www.mathworks.com/supportfiles/spc/SpectrumSensing/SpectrumSenseTrainingDataNetwork.tar.gz). For more information on how this dataset was generated or to generate further spectrum data, please -refer to MathWork's article on spectrum sensing. 
For more information about Qoherent's use of MATLAB to accelerate -intelligent radio research, check out our [customer story](https://www.mathworks.com/company/user_stories/qoherent-uses-matlab-to-accelerate-research-on-next-generation-ai-for-wireless.html). +The dataset used in this example was prepared by MathWorks and is publicly available [here](https://www.mathworks.com/supportfiles/spc/SpectrumSensing/SpectrumSenseTrainingDataNetwork.tar.gz). For more information +on how this dataset was generated or to generate further spectrum data, please refer to MathWorks' article on spectrum +sensing. For more information about Qoherent's use of MATLAB to accelerate intelligent radio research, check out our +[customer story](https://www.mathworks.com/company/user_stories/qoherent-uses-matlab-to-accelerate-research-on-next-generation-ai-for-wireless.html). The DeepLabv3 models used in this example were initially proposed by Chen _et al._ and are further discussed in their 2017 paper titled '[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)'. The MobileNetV3 diff --git a/download_dataset.py b/download_dataset.py index f8a6cf7..9ba564e 100644 --- a/download_dataset.py +++ b/download_dataset.py @@ -1,5 +1,5 @@ """ -Download MathWorks' Spectrum Sensing 5G dataset, if it isn't already downloaded. +Download MathWorks' Spectrum Sensing dataset, if it isn't already downloaded. """ import os diff --git a/spectrogram_segmentation.ipynb b/spectrogram_segmentation.ipynb index 227a676..12ef62b 100644 --- a/spectrogram_segmentation.ipynb +++ b/spectrogram_segmentation.ipynb @@ -46,11 +46,10 @@ "the field of computer vision to the problem of spectrogram analysis. Our task is to assign one of the \n", "following labels to each pixel in the spectrogram: 'LTE', 'NR', or 'Noise'. ('Noise' refers to the absence of signal, representing \n", "a vacant or empty spectrum, also known as whitespace.)\n", - ".\n", "\n", "The machine learning model utilized in this example is a DeepLabV3 model with a MobileNetV3 large backbone. The DeepLabv3 framework was originally introduced by Chen _et al._ in their 2017 paper titled '[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587) and the MobileNetV3 backbone was developed by Howard _et al._ and is further discussed in their 2019 paper titled '[Searching for MobileNetV3](https://arxiv.org/abs/1905.02244)'. For an accessible introduction to the DeepLabV3 framework, please check out Isaac Berrios' article: [DeepLabv3: Building Blocks for Robust Segmentation Models](https://medium.com/@itberrios6/deeplabv3-c0c8c93d25a4).\n", "\n", - "The dataset used in this example is the Spectrum Sensing dataset, provided by MathWorks. This dataset contains 900 LTE frames, 900 NR frames, and 900 combined frames with both LTE and NR signal. In this example, we train exclusively on the individual LTE and NR examples, excluding the combined frames." + "The dataset used in this example is the Spectrum Sensing dataset, provided by MathWorks. This dataset contains 900 LTE frames, 900 NR frames, and 900 combined frames with both LTE and NR signal. In this example, we train exclusively on the individual LTE and NR examples, excluding the combined frames from the training process." ] }, { @@ -200,7 +199,7 @@ "and `target_transform`, which is applied to the mask.\n", "\n", "Both the spectrograms and masks are 256 x 256 pixel images. However, the spectrograms are three channeled, while the masks are single-channeled. 
This is because the spectrograms are full RGB images, whereas the masks are ternary-valued images, where each pixel takes one of three discrete values:\n", - "- `0`: Represents noise.\n", + "- `0`: Representing noise.\n", "- `127`: Representing NR signal.\n", "- `255`: Representing LTE signal.\n", "\n", @@ -220,6 +219,7 @@ "\n", "\n", "class Squeeze(torch.nn.Module):\n", + " # Mapping 0 -> 0, 127 -> 1, and 255 -> 2.\n", " def forward(self, target: Tensor):\n", " return torch.squeeze(target)\n", "\n", @@ -237,9 +237,7 @@ " ]\n", ")\n", "\n", - "target_transform = Compose(\n", - " [PILToTensor(), Squeeze(), DivideBy127(), ToDtype(torch.long)] # Mapping 0 -> 0, 127 -> 1, and 255 -> 2.\n", - ")" + "target_transform = Compose([PILToTensor(), Squeeze(), DivideBy127(), ToDtype(torch.long)])" ] }, { @@ -269,7 +267,7 @@ "random_index = np.random.randint(len(dataset))\n", "training_example, corresponding_mask = dataset[random_index]\n", "\n", - "print(f\"The full dataset has {len(dataset)} examples. Loading example at index {random_index}:\")\n", + "print(f\"The full dataset has {len(dataset)} examples. Loading example at index {random_index}.\")\n", "print(f\"Spectrogram: {type(training_example)}, {training_example.dtype}, {training_example.size()}\")\n", "print(f\"Mask: {type(corresponding_mask)}, {corresponding_mask.dtype}, {corresponding_mask.size()}\")" ] }, { @@ -490,7 +488,7 @@ "source": [ "# Model Training\n", "\n", - "In this example, we'll use a DeepLabV3 model with a MobileNetV3 large backbones. This model is designed to be lightweight and efficient, making it ideal for edge computing devices and quick proof-of-concept demonstrations" + "In this example, we'll use a DeepLabV3 model with a MobileNetV3 large backbone. This model is designed to be lightweight and efficient, making it ideal for edge computing devices and quick proof-of-concept demonstrations." ] }, { @@ -512,7 +510,7 @@ "feedback that guides the model's training process. For classification problems, we commonly use the [Cross-Entropy Loss](https://machinelearningmastery.com/cross-entropy-for-machine-learning/), especially for \n", "multi-class classification problems. Let's use the [`CrossEntropyLoss`](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) class from PyTorch, which allows us to assign different weights to individual classes during the computation of the loss. \n", "\n", - "We'll use weights inversely proportional to the relative pixel count for each class. That way, we assign lower weights to overrepresented classes, like noise, and larger weights to underrepresented classes, like LTE signal. This reduces the impact of noise and allows the model to prioritize learning from NR and especially LTE samples. Class weighting is not the only way to address data imblance, but it is one of the more straightforward methods." + "We'll use weights inversely proportional to the relative pixel count for each class. That way, we assign lower weights to overrepresented classes, like noise, and larger weights to underrepresented classes, like LTE signal. This reduces the impact of noise and allows the model to prioritize learning from NR and LTE samples. Class weighting is not the only way to address data imbalance, but it is one of the more straightforward methods." 
] }, { @@ -615,10 +613,10 @@ "The following hyperparameters are used to configure the optimizer:\n", "- **Momentum:** A parameter that accelerates SGD in the relevant direction and dampens oscillations.\n", "- **Learning Rate:** The rate at which the model parameters are updated during optimization.\n", - "- **Weight Decay:** A regularization term added to the loss function to penalize large weights in the model to prevent overfitting\n", + "- **Weight Decay:** A regularization term added to the loss function to penalize large weights in the model to prevent overfitting.\n", "\n", - "By gradually reducing the learning rate over epochs, the scheduler can help improve the convergence and stability of the optimization process\n", - "We need to provide the following two parameters, which the learning rate scheduler uses to dynamically adjust the learning rate during training:\n", + "By gradually reducing the learning rate over epochs, the scheduler can help improve the convergence and stability of the optimization process.\n", + "We need to provide the following two parameters, which will be used by the learning rate scheduler to dynamically adjust the learning rate during training:\n", "- **Step Size:** The number of epochs after which the learning rate is reduced.\n", "- **Gamma:** The factor by which the learning rate is reduced after every step-size epochs.\n", "\n", @@ -647,9 +645,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now that we have our model, weighted loss function, and Lightning Module, we are prepared to train our model. If available, we will leverage GPU acceleration. Otherwise, the training process will default to using the CPU. Please be patient; model training time may vary depending on the current hardware configuration and could take a few minutes.\n", + "Now that we have our model, weighted loss function, and Lightning Module, we are prepared to train our model. If available, we will leverage GPU acceleration. Otherwise, the training process will default to using the CPU. Please be patient; model training time may vary from a few minutes to a couple hours depending on the current hardware configuration.\n", "\n", - "The number of epochs determines how many times the entire dataset will be used to train the model. For this specific model and dataset, 10 epochs should be more than sufficient." + "The number of epochs determines how many times the entire dataset will be used to train the model. For this specific model and dataset, 10 epochs should be more than sufficient. However, if you are training on the CPU, you might want to consider reducing the number of training epochs to 4 to save on training time." ] }, { @@ -659,6 +657,7 @@ "outputs": [], "source": [ "n_epochs = 10\n", + "# n_epochs = 4 # Suggested for CPU training.\n", "\n", "if torch.cuda.is_available():\n", " print(\"Training model on GPU.\")\n", @@ -685,9 +684,9 @@ "source": [ "# Model Validation\n", "\n", - "Having trained our model, the next step is to evaluate its performance. To accomplish this, we'll use a suite of standard machine learning metrics. But first, let's take a look at a random batch of predictions and true labels.\n", + "Having trained our model, the next step is to evaluate its performance. To accomplish this, we'll use a suite of standard machine learning metrics. But first, let's take a look at a random batch of predictions.\n", "\n", - "Because the model returns the unnormalized probabilities corresponding to the predictions of each class. 
We need to use `argmax()` to get the maximum prediction of each class. The result is a ternary-valued image for each example in the batch." + "Because the model returns the unnormalized probabilities corresponding to the predictions of each class, we need to use `argmax()` to obtain the class with the highest prediction probability. The result is a single-channel image for each example in the batch, which can be compared directly to the corresponding target mask." ] }, { @@ -702,7 +701,7 @@ "spects, masks = next(iter(train_loader))\n", "spects = spects.to(device)\n", "\n", - "with torch.no_grad():\n", + "with torch.inference_mode():\n", " preds = (model(spects)[\"out\"]).argmax(1)\n", "\n", "print(\"Predictions:\", preds.size())" ] }, { @@ -723,7 +722,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Looks pretty good! But to get a more ojective sense, let's turn to the metrics. Let's start with model accuracy, calculated as the ratio of correctly predicted pixels to the total number of pixels.\n", + "Looks pretty good! But to get a more objective sense, let's turn to the metrics. Let's start with model accuracy, calculated as the ratio of correctly predicted pixels to the total number of pixels.\n", "\n", "**Note:** You can view different predictions by rerunning the previous few code cells." ] }, { @@ -755,7 +754,7 @@ "model.to(device)\n", "confusion_matrix = MulticlassConfusionMatrix(num_classes=n_classes, normalize=\"true\").to(device)\n", "\n", - "with torch.no_grad():\n", + "with torch.inference_mode():\n", " for spect, mask in val_loader:\n", " spect, mask = spect.to(device), mask.to(device)\n", " pred = (model(spect)[\"out\"]).argmax(dim=1)\n", @@ -777,7 +776,7 @@ "\n", "- **F1 Score:** The F1 score combines both recall and precision into a single value, providing a more balanced measure of the model's performance. A higher F1 indicates a model with both good precision and recall (fewer false positives and false negatives overall).\n", "\n", - "- **Intersection over Union (IoU):** The IoU, commonly called Jaccard's Index, quantifies the overlap between the predicted bounding box or segmented region and the ground truth. A higher IoU value indicates a better alignment between the predicted and actual regions, reflecting a more accurate model." + "- **Intersection over Union (IoU):** The IoU quantifies the overlap between the predicted bounding box or segmented region and the ground truth. A higher IoU value indicates a better alignment between the predicted and actual regions, reflecting a more accurate model." 
] }, { @@ -794,11 +793,12 @@ " MulticlassRecall(num_classes=n_classes, average=None),\n", " MulticlassPrecision(num_classes=n_classes, average=None),\n", " MulticlassF1Score(num_classes=n_classes, average=None),\n", + " # The IoU is commonly referred to as Jaccard's Index\n", " MulticlassJaccardIndex(num_classes=n_classes, average=None),\n", " ]\n", " metrics = [m.to(device) for m in metrics]\n", "\n", - " with torch.no_grad():\n", + " with torch.inference_mode():\n", " for spect, mask in dataloader:\n", " spect, mask = spect.to(device), mask.to(device)\n", " pred = (model(spect)[\"out\"]).argmax(dim=1)\n", @@ -815,7 +815,7 @@ " \"Recall\": metrics[1],\n", " \"Precision\": metrics[2],\n", " \"F1 Score\": metrics[3],\n", - " \"Jaccard Index\": metrics[4],\n", + " \"IoU\": metrics[4],\n", " }\n", " )\n", " print(\n", @@ -840,7 +840,7 @@ "source": [ "# Challenge Data\n", "\n", - "In machine leaning, generalization refers to the ability of a trained model to perform well on new data that it hasn't been trained on. As an easy way to test the generalization of our model, let's test on the combined frames with both LTE and NR signal. As a reminder, such frames were excluded from training." + "In machine learning, out-of-distribution data refers to examples that deviate from those used during training. For example, recall that the Spectrum Sensing dataset comprises 900 combined frames featuring both LTE and NR signals. As we excluded the combined frames from the training process, they represent out-of-distribution data. To get a quick sense of how our model performs on these combined frames, let's take a look at a random batch of predictions." ] }, { @@ -864,7 +864,7 @@ "spects, masks = next(iter(challenge_loader))\n", "spects = spects.to(device)\n", "\n", - "with torch.no_grad():\n", + "with torch.inference_mode():\n", " preds = (model(spects)[\"out\"]).argmax(1)\n", "\n", "print(\"Predictions:\", preds.size())" ] }, { @@ -887,7 +887,7 @@ "source": [ "**Note:** You can view different examples by rerunning the previous few code cells.\n", "\n", - "Now, let's evaluate the same metrics as we did above in the [Model Validation](#Model-Validation) section, but now for the challenge dataset. Given that these combined frames represent a more challenging problem, we anticipate the model's capabilities to be somewhat diminished, yet we still anticipate reasonable results." + "Now, let's evaluate the same metrics as we did above in the [Model Validation](#Model-Validation) section, but now for the challenge dataset. Given the model's lack of exposure to these combined frames during the training process, we anticipate the model's capabilities to be somewhat diminished. Yet, we still expect reasonable results." ] }, { @@ -905,7 +905,7 @@ "source": [ "# Conclusions & Next Steps\n", "\n", - "In this example, we used PyTorch and PyTorch Lightning to train DeepLabV3 models to identify and differentiate between 5G NR and 4G LTE signals within wideband spectrograms, showcasing one of the ways we can leverage machine learning to identify things in the wireless spectrum. 
This involved data analysis and preprocessing, choosing a loss function and optimizer, model training, model performance validation, and finally testing the model's generalization on combined frames containing both NR and LTE signals, \n", + "In this example, we used PyTorch and PyTorch Lightning to train a DeepLabV3 model to identify and differentiate between 5G NR and 4G LTE signals within wideband spectrograms, showcasing one of the ways we can leverage machine learning to identify things in the wireless spectrum. This involved data analysis and preprocessing, choosing a loss function and optimizer, model training, model performance validation, and finally testing the model on out-of-distribution frames containing both NR and LTE signals.\n", "\n", "The capability to differentiate and recognize various signals finds direct applications in spectrum sensing, which is fundamental to autonomous spectrum management, and brings us one step closer to more holistic cognitive radio solutions! 📡🚀" ] @@ -916,9 +916,9 @@ "source": [ "We hope this example was informative. Here are some next steps you can take to further explore and expand upon what you've learned:\n", "\n", - "- **Experiment with the Hyperparameters:** Adjust the values of hyperparameters such as the number of training epochs, batch size, and learning rate, and observe how these configurations influence model training, performance, and generalization capabilities. After gaining insights through manual hyperparameter tuning, explore automated approaches using tools like [Ray Tune](https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html) or [Optuna](https://optuna.org/).\n", + "- **Experiment with the Hyperparameters:** Adjust the values of hyperparameters such as the number of training epochs, batch size, and learning rate, and observe how these configurations influence model training and performance. After gaining insights through manual hyperparameter tuning, explore automated approaches using tools like [Ray Tune](https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html) or [Optuna](https://optuna.org/).\n", "\n", - "- **Experiment with DeepLabV3's ResNet Models:** DeepLabV3 also provides models with ResNet-50 and ResNet-101 backbones. These ResNet models are deeper and more complex, and generally offers better model performance than MobileNetV3, which is designed to be lightweight and efficient. Because all DeepLabV3 models implement the same interface, no code changes are required. However, some hyperparameter tuning and/or a larger dataset may be required to train these models effectively. These models have already been imported for your convenience.\n", + "- **Experiment with DeepLabV3's ResNet Models:** DeepLabV3 also provides models with ResNet-50 and ResNet-101 backbones. These ResNet models are deeper and more complex, and generally offer better model performance than MobileNetV3, which is designed to be lightweight and efficient. Because all DeepLabV3 models implement the same interface, no code changes are required. However, some hyperparameter tuning and/or a larger dataset may be required to train these models effectively. These models have already been imported into this notebook for your convenience.\n", "\n", "- **Explore Alternative Solutions to Class Imbalance:** In this example, we addressed class imbalance in our dataset using a weighted cross-entropy loss function. 
Research and implement alternative strategies or loss functions designed to address imbalance in image datasets.\n", "\n",