Class activation maps for your PyTorch models (CAM, Grad-CAM, Grad-CAM++, Smooth Grad-CAM++, Score-CAM, SS-CAM, IS-CAM, XGrad-CAM, Layer-CAM)
Apache-2.0 License
Published by frgfm about 1 year ago
This minor release adds evaluation metrics to the package and bumps the minimum supported PyTorch version to 2.0.
Note: TorchCAM 0.4.0 requires PyTorch 2.0.0 or higher.
This release comes with a standard way to evaluate interpretability methods. This allows users to better evaluate models' robustness:
from functools import partial
import torch
from torchcam.metrics import ClassificationMetric

# cam_extractor is an already-instantiated CAM extractor hooked on your model
metric = ClassificationMetric(cam_extractor, partial(torch.softmax, dim=-1))
metric.update(input_tensor)
metric.summary()
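The second argument to `ClassificationMetric` is the function that turns the model's raw logits into probabilities; here, `partial(torch.softmax, dim=-1)`. As a refresher, this is what softmax computes, shown as a minimal framework-free sketch (the helper and its inputs are illustrative, not part of TorchCAM):

```python
import math

def softmax(logits):
    """Numerically stable softmax: exponentiate shifted scores, then normalize."""
    m = max(logits)  # subtract the max to avoid overflow in exp()
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Higher logits map to higher probabilities, and the output sums to 1
probs = softmax([1.0, 2.0, 3.0])
```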
Full Changelog: https://github.com/frgfm/torch-cam/compare/v0.3.2...v0.4.0
Published by frgfm about 2 years ago
This patch release fixes the Score-CAM methods and improves the base API for CAM computation.
Note: TorchCAM 0.3.2 requires PyTorch 1.7.0 or higher.
CAM computation now supports batch sizes larger than 1 (#143)! Practically, this means that you can compute CAMs for multiple samples at the same time, which also lets you make the most of your GPU.
The following snippet:
import torch
from torchcam.methods import LayerCAM
from torchvision.models import resnet18
# A preprocessed (resized & normalized) tensor
img_tensor = torch.rand((2, 3, 224, 224))
model = resnet18(pretrained=True).eval()
# Hook your model before inference
cam_extractor = LayerCAM(model)
out = model(img_tensor)
# Compute the CAM
activation_map = cam_extractor(out[0].argmax().item(), out)
print(activation_map[0].ndim)
will yield 3, since the batch dimension is now included.
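Since each returned map now carries a leading batch dimension, a typical next step is to min-max rescale each sample's map to [0, 1] before visualization. Here is a minimal framework-free sketch (the helper name is ours, and it assumes the map was converted to nested Python lists, e.g. with `.tolist()`):

```python
def minmax_rescale(cam_2d):
    """Rescale a 2D activation map (list of rows) to the [0, 1] range."""
    flat = [v for row in cam_2d for v in row]
    lo, hi = min(flat), max(flat)
    span = (hi - lo) or 1.0  # guard against a constant map
    return [[(v - lo) / span for v in row] for row in cam_2d]

# Each sample in the batch gets rescaled independently
batch = [[[0.0, 2.0], [4.0, 8.0]], [[1.0, 1.0], [1.0, 1.0]]]
rescaled = [minmax_rescale(cam) for cam in batch]
```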
New year, new documentation theme!
For clarity and an improved interface, the documentation theme was changed from Read the Docs to Furo (#162). This comes with nice features like dark mode and an edit button!
Contributions are important to OSS projects, and for this reason, a few improvements were made to the contribution process.
CAM extractors now output a list of tensors, with one element per target layer, in the same order as the target layers.
Each of these elements used to be a 2D spatial tensor, and is now a 3D tensor to include the batch dimension:
# Model was hooked and a tensor of shape (2, 3, 224, 224) was forwarded to it
amaps = cam_extractor(0, out)
for elt in amaps: print(elt.shape)
will, from now on, yield
torch.Size([2, 7, 7])
Full Changelog: https://github.com/frgfm/torch-cam/compare/v0.3.1...v0.3.2
Published by frgfm almost 3 years ago
This patch release adds new features to the demo and reorganizes the package for a clearer hierarchy.
Note: TorchCAM 0.3.1 requires PyTorch 1.5.1 or higher.
With release 0.3.0, the support of multiple target layers was added as well as CAM fusion. The demo was updated to automatically fuse CAMs when you hooked multiple layers (add a "+" separator between each layer name):
To anticipate further developments of the library, several modules were renamed:
- `torchcam.cams` was renamed to `torchcam.methods`
- `torchcam.cams.utils` was renamed and made private (`torchcam.methods._utils`), since its API may evolve quickly
- `torchcam.methods.activation` rather than `torchcam.cams.cam`
- `torchcam.methods.gradient` rather than `torchcam.cams.gradcam`
| 0.3.0 | 0.3.1 |
|---|---|
| `>>> from torchcam.cams import LayerCAM` | `>>> from torchcam.methods import LayerCAM` |
Full Changelog: https://github.com/frgfm/torch-cam/compare/v0.3.0...v0.3.1
Published by frgfm almost 3 years ago
This release extends CAM methods with Layer-CAM and greatly improves the core features (CAM computation for multiple layers at once, CAM fusion, support of `torch.nn.Module` target layers), while improving accessibility for new users.
Note: TorchCAM 0.3.0 requires PyTorch 1.5.1 or higher.
The previous release introduced Score-CAM variants; this one introduces Layer-CAM, which is meant to be considerably faster while offering very competitive localization cues!
Just like any other CAM method, you can use it as follows:
from torchcam.cams import LayerCAM
# model = ....
# Hook the model
cam_extractor = LayerCAM(model)
Consequently, the illustration of visual outputs for all CAM methods has been updated so that you can better choose the option that suits you:
A class activation map is specific to a given layer in a model. To fully capture the influence of visual traits on your classification output, you might want to explore the CAMs for multiple layers.
For instance, here are the CAMs on the layers "layer2", "layer3" and "layer4" of a `resnet18`:
from torchvision.io.image import read_image
from torchvision.models import resnet18
from torchvision.transforms.functional import normalize, resize, to_pil_image
import matplotlib.pyplot as plt
from torchcam.cams import LayerCAM
from torchcam.utils import overlay_mask
# Download an image
!wget https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg
# Set this to your image path if you wish to run it on your own data
img_path = "border-collie.jpg"
# Get your input
img = read_image(img_path)
# Preprocess it for your chosen model
input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Get your model
model = resnet18(pretrained=True).eval()
# Hook the model
cam_extractor = LayerCAM(model, ["layer2", "layer3", "layer4"])
out = model(input_tensor.unsqueeze(0))
cams = cam_extractor(out.squeeze(0).argmax().item(), out)
# Plot the CAMs
_, axes = plt.subplots(1, len(cam_extractor.target_names))
for idx, (name, cam) in enumerate(zip(cam_extractor.target_names, cams)):
    axes[idx].imshow(cam.numpy()); axes[idx].axis('off'); axes[idx].set_title(name)
plt.show()
Now, the way you would combine those together is up to you. By default, most approaches use an element-wise maximum. But, LayerCAM has its own fusion method:
# Let's fuse them
fused_cam = cam_extractor.fuse_cams(cams)
# Plot the raw version
plt.imshow(fused_cam.numpy()); plt.axis('off'); plt.title(" + ".join(cam_extractor.target_names)); plt.show()
# Overlay it on the image
result = overlay_mask(to_pil_image(img), to_pil_image(fused_cam, mode='F'), alpha=0.5)
# Plot the result
plt.imshow(result); plt.axis('off'); plt.title(" + ".join(cam_extractor.target_names)); plt.show()
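To make the default strategy concrete, here is a framework-free sketch of element-wise maximum fusion (an illustrative helper of our own, not TorchCAM's `fuse_cams`, which additionally interpolates the maps to a common spatial size first):

```python
def fuse_elementwise_max(cams):
    """Fuse several same-sized 2D maps by taking the per-position maximum."""
    fused = cams[0]
    for cam in cams[1:]:
        fused = [[max(a, b) for a, b in zip(row_f, row_c)]
                 for row_f, row_c in zip(fused, cam)]
    return fused

cam_a = [[0.1, 0.9], [0.4, 0.2]]
cam_b = [[0.3, 0.5], [0.8, 0.1]]
fused = fuse_elementwise_max([cam_a, cam_b])  # → [[0.3, 0.9], [0.8, 0.2]]
```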
`torch.nn.Module` as `target_layer`
While making the API more robust, CAM constructors now also accept `torch.nn.Module` objects as `target_layer`. Previously, you had to pass the name of the layer as a string; you can now pass the object reference directly if you prefer:
from torchcam.cams import LayerCAM
# model = ....
# Hook the model
cam_extractor = LayerCAM(model, model.layer4)
Since CAMs can be used in localization or production pipelines, it is important to consider latency along with pure visual output quality. For this reason, this release includes a latency evaluation script along with a full benchmark table.
Should you wish to have latency metrics on your dedicated hardware, you can run the script on your own:
python scripts/eval_latency.py SmoothGradCAMpp --size 224
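If you would rather time things inline than run the script, a hand-rolled helper along these lines gives a rough wall-clock estimate (`time_call` is our own illustrative function, not part of TorchCAM; for GPU workloads, proper benchmarking also needs warm-up runs and device synchronization):

```python
import time

def time_call(fn, n_runs=10):
    """Average wall-clock latency of fn() over n_runs calls, in seconds."""
    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    return (time.perf_counter() - start) / n_runs

# Example: time a cheap pure-Python call
avg_s = time_call(lambda: sum(range(1000)), n_runs=5)
```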
Do you prefer to only run code rather than write it? Perhaps you only want to tweak a few things?
Then enjoy the brand new Jupyter notebooks, which you can run either locally or on Google Colab!
The ML community was recently blessed by Hugging Face with the beta of Spaces, which lets you host your ML demos free of charge!
Previously, you could run the demo locally or deploy it on your own; now, you can enjoy the live demo of TorchCAM.
Since CAM extractors can now compute the resulting maps for multiple layers at a time, the return type of all CAMs has been changed from `torch.Tensor` to `List[torch.Tensor]` with N elements, where N is the number of target layers.
| 0.2.0 | 0.3.0 |
|---|---|
| `>>> from torchcam.cams import SmoothGradCAMpp`<br>`>>> cam_extractor = SmoothGradCAMpp(model)`<br>`>>> out = model(input_tensor.unsqueeze(0))`<br>`>>> print(type(cam_extractor(out.squeeze(0).argmax().item(), out)))`<br>`<class 'torch.Tensor'>` | `>>> from torchcam.cams import SmoothGradCAMpp`<br>`>>> cam_extractor = SmoothGradCAMpp(model)`<br>`>>> out = model(input_tensor.unsqueeze(0))`<br>`>>> print(type(cam_extractor(out.squeeze(0).argmax().item(), out)))`<br>`<class 'list'>` |
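Code written against 0.2.0 can be adapted with a small shim that normalizes both return types (a hypothetical helper for illustration, not part of TorchCAM):

```python
def as_cam_list(output):
    """Normalize the extractor output across versions: 0.2.0 returned a single
    map, while 0.3.0+ returns a list with one map per hooked layer."""
    return output if isinstance(output, list) else [output]

old_style = "cam"               # stand-in for a single torch.Tensor
new_style = ["cam_a", "cam_b"]  # stand-in for List[torch.Tensor]
wrapped_old = as_cam_list(old_style)  # → ["cam"]
wrapped_new = as_cam_list(new_style)  # → ["cam_a", "cam_b"]
```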
Implementations of CAM method
- `torch.nn.Module` as `target_layer` or `fc_layer` #83 (@frgfm)

Side scripts to make the most out of TorchCAM

Verifications of the package well-being before release
- `torch.nn.Module` as `target_layer` in CAM constructor #83 #88 (@frgfm)

Online resources for potential users
- `.fuse_cams` method #93 (@frgfm)

Other tools and implementations
- `class_idx` & `target_layer` selection in the demo #67 (@frgfm)
- `docutils` version constraint for documentation building #98 (@frgfm)

Published by frgfm over 3 years ago
This release extends TorchCAM compatibility to 3D inputs, and improves documentation.
Note: TorchCAM 0.2.0 requires PyTorch 1.5.1 or higher.
The first papers on CAM methods targeted classification models with spatially 2D inputs. However, the latest methods can be extended to higher-dimensional inputs, and this is now supported:
import torch
from torchcam.cams import SmoothGradCAMpp
# Define your model compatible with 3D inputs
video_model = ...
extractor = SmoothGradCAMpp(video_model)
# Forward your input
scores = video_model(torch.rand((1, 3, 32, 224, 224)))
# Retrieve the CAM
cam = extractor(scores[0].argmax().item(), scores)
While the documentation was kept up to date with the latest commit on the main branch, there was previously no corresponding documentation if you were running an older release of the library.
From now on, you can select the version of the documentation you wish to access (stable releases or the latest commit):
Since spatial information is at the very core of TorchCAM, a minimal Streamlit demo app was added to explore the activations of your favorite models. You can run the demo with the following command:
streamlit run demo/app.py
Here is how it renders when retrieving the heatmap using `SmoothGradCAMpp` on a pretrained `resnet18`:
Implementations of CAM method

Verifications of the package well-being before release

Online resources for potential users

Other tools and implementations
- `overlay_mask` #38 (@alexandrosstergiou)
- `overlay_mask` unittest #38 (@alexandrosstergiou)
- `unittest` to `pytest` #45 (@frgfm) and split test files by module

Published by frgfm almost 4 years ago
This release adds an implementation of IS-CAM and greatly improves the interface.
Note: torchcam 0.1.2 requires PyTorch 1.1 or newer.
Implementation of CAM extractor
New
Improvements
Fixes
Verifications of the package well-being before release
New
- `torchcam.cams` #13, #30 (@frgfm)

Improvements
Online resources for potential users
New
Fixes
Other tools and implementations
New
Improvements
Fixes
Published by frgfm about 4 years ago
This release adds implementations of SmoothGradCAM++, Score-CAM and SS-CAM.
Note: torchcam 0.1.1 requires PyTorch 1.1 or newer.
brought to you by @frgfm
Implementation of CAM extractor
New
Improvements
Verifications of the package well-being before release
New
- `torchcam.cams` (#4, #5, #11)

Online resources for potential users
Improvements
Other tools and implementations
Published by frgfm over 4 years ago
This release adds implementations of CAM, GradCAM and GradCAM++.
Note: torchcam 0.1.0 requires PyTorch 1.1 or newer.
brought to you by @frgfm
Implementation of gradient-based CAM extractor
New
Verifications of the package well-being before release
New
- `torchcam.cams` (#1, #2)
- `torchcam.utils` (#1)

Online resources for potential users
New
Other tools and implementations