Modality-Agnostic Learning for Medical Image Segmentation Using Multi-modality Self-distillation
APACHE-2.0 License
This is the official implementation of MAG-MS.
For the official implementation of MAGNET: A Modality-Agnostic Networks for Medical Image Segmentation, please check out the stable-1.1 branch.
MAG-MS is designed to be compatible with MAGNET (v1). The new MAGNET (v2) used in MAG-MS is designed to support multi-modality self-distillation and multi-modality feature distillation.
Use the package manager pip to install MAG-MS.
pip install magms
training_dataset = ...
validation_dataset = ...
Build MAGNET with the magnet.build function, or use the magnet.build_v2 (UNETR backbone) or magnet.build_v2_unet (3D UNet backbone) function for the new MAGNET used in MAG-MS:
num_modalities: int = ...
num_classes: int = ...
img_size: Union[int, Sequence[int]] = ...
target_dict = ...
model = magnet.build_v2(num_modalities, num_classes, img_size, target_dict=target_dict)
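The img_size argument above accepts either a single int or a sequence of ints. The sketch below shows one common way such an argument can be normalized, assuming (this is not confirmed by this README) that a single int stands for a cubic 3D volume. The helper normalize_img_size and its spatial_dims parameter are hypothetical and are not part of the magnet package.

```python
from typing import Sequence, Union


def normalize_img_size(
    img_size: Union[int, Sequence[int]], spatial_dims: int = 3
) -> tuple[int, ...]:
    """Hypothetical helper: expand an int to one size per spatial dimension.

    Assumption: an int such as 96 means a 96 x 96 x 96 volume.
    """
    if isinstance(img_size, int):
        # Repeat the same edge length for every spatial dimension
        return (img_size,) * spatial_dims
    size = tuple(int(s) for s in img_size)
    if len(size) != spatial_dims:
        raise ValueError(
            f"expected {spatial_dims} spatial sizes, got {len(size)}"
        )
    return size
```

For example, normalize_img_size(96) and normalize_img_size([96, 96, 96]) would describe the same volume under this assumption.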
Alternatively, use the magnet.nn framework to customize the MAGNET backbone:
encoder1: torch.nn.Module = ...
encoder2: torch.nn.Module = ...
fusion: torch.nn.Module = ...
decoder: torch.nn.Module = ...
model = magnet.nn.MAGNET2(encoder1, encoder2, fusion=fusion, decoder=decoder)
main_loss_fn: list[Callable[[Any, Any], torch.Tensor]] = ...
kldiv_loss_fn: list[Callable[[Any, Any], torch.Tensor]] = ...
mse_loss_fn: list[Callable[[Any, Any], torch.Tensor]] = ...
self_distillation_loss_fn = magnet.losses.MAGSelfDistillationLoss(main_loss_fn, kldiv_loss_fn)
feature_distillation_loss_fn = magnet.losses.MAGFeatureDistillationLoss(self_distillation_loss_fn, mse_loss_fn)
loss_fn = feature_distillation_loss_fn
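To make the roles of the three loss terms above more concrete, here is a minimal pure-Python sketch of how a main loss, a KL-divergence term, and an MSE feature term could be combined. It is not the implementation of magnet.losses.MAGSelfDistillationLoss or magnet.losses.MAGFeatureDistillationLoss. It assumes the fused multi-modality output acts as the teacher and a single-modality output as the student, and the names distillation_loss, kd_weight, feat_weight, and temperature are hypothetical.

```python
import math
from typing import Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> list[float]:
    """Turn logits into probabilities, optionally softened by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(teacher: Sequence[float], student: Sequence[float]) -> float:
    """KL(teacher || student) for two discrete probability distributions."""
    return sum(t * math.log(t / s) for t, s in zip(teacher, student) if t > 0.0)


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equally sized feature vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distillation_loss(
    main_loss: float,
    teacher_logits: Sequence[float],
    student_logits: Sequence[float],
    teacher_features: Sequence[float],
    student_features: Sequence[float],
    kd_weight: float = 1.0,      # hypothetical weight of the KL term
    feat_weight: float = 1.0,    # hypothetical weight of the MSE term
    temperature: float = 1.0,    # hypothetical softening temperature
) -> float:
    """Sum of the main loss, a self-distillation term, and a feature term."""
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    kd_term = kl_divergence(teacher_probs, student_probs)
    feat_term = mse(teacher_features, student_features)
    return main_loss + kd_weight * kd_term + feat_weight * feat_term
```

When the student matches the teacher exactly, both distillation terms vanish and only the main loss remains, which is the behavior one would expect from this kind of combined objective.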
optimizer = ...
metric_fns = ...
epochs = ...
callbacks = ...
manager = magnet.Manager(model, optimizer, loss_fn=loss_fn, metric_fns=metric_fns)
manager.fit(training_dataset, epochs, val_dataset=validation_dataset, callbacks=callbacks)
summary = manager.test(validation_dataset)
print(summary)
Alternatively, use magnet.MonaigManager instead of Manager, and pass post_labels and post_predicts:
post_labels = [...]
post_predicts = [...]
manager = magnet.MonaigManager(model, post_labels=post_labels, post_predicts=post_predicts, optimizer=optimizer, loss_fn=loss_fn, metric_fns=metric_fns)
@article{he2023modality,
title={Modality-Agnostic Learning for Medical Image Segmentation Using Multi-modality Self-distillation},
author={He, Qisheng and Summerfield, Nicholas and Dong, Ming and Glide-Hurst, Carri},
journal={arXiv preprint arXiv:2306.03730},
year={2023}
}