PyTorch implementation of Channel Distillation: Channel-Wise Attention for Knowledge Distillation
Note
In our code, kdv2 refers to GKD (Guided Knowledge Distillation) and lrdv2 refers to EDT (Early Decay Teacher).
├── cifar_config.py  # Hyperparameters
├── cifar_train.py
├── data
│   └── directory_of_data.md
├── imagenet_config.py  # Hyperparameters
├── imagenet_train.py
├── losses
│   ├── cd_loss.py  # CD Loss
│   ├── ce_loss.py
│   ├── __init__.py
│   └── kd_loss.py  # GKD Loss
├── models
│   ├── channel_distillation.py  # Distillation Network
│   ├── __init__.py
│   └── resnet.py
├── pretrain
│   └── path_of_teacher_checkpoint.md
├── README.md
└── utils
    ├── average_meter.py
    ├── data_prefetcher.py
    ├── __init__.py
    ├── logutil.py
    ├── metric.py
    └── util.py  # Early Decay Teacher
python >= 3.7
torch >= 1.4.0
torchvision >= 0.5.0
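The minimum versions above can be checked before launching a run. The following is a minimal sketch; the helper names `parse_version` and `meets_minimum` are hypothetical and not part of this repository, and the parsing only looks at the first three numeric components of a version string.

```python
import re
import sys

# Minimum versions listed in the requirements above.
MINIMUMS = {"torch": (1, 4, 0), "torchvision": (0, 5, 0)}


def parse_version(text):
    """Turn '1.13.1+cu117' into (1, 13, 1); local build tags are dropped."""
    numbers = re.findall(r"\d+", text.split("+")[0])
    numbers = (numbers + ["0", "0", "0"])[:3]  # pad short versions such as '1.4'
    return tuple(int(n) for n in numbers)


def meets_minimum(installed, minimum):
    """Return True if the installed version string is at least `minimum`."""
    return parse_version(installed) >= minimum


if __name__ == "__main__":
    print("python ok:", sys.version_info >= (3, 7))
    for name, minimum in MINIMUMS.items():
        try:
            module = __import__(name)
        except ImportError:
            print(f"{name}: not installed")
            continue
        print(f"{name} ok:", meets_minimum(module.__version__, minimum))
```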
Images should be arranged as follows:
./data/ILSVRC2012/train/dog/xxx.png
./data/ILSVRC2012/train/cat/xxy.png
./data/ILSVRC2012/val/dog/xxx.png
./data/ILSVRC2012/val/cat/xxy.png
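The layout above uses one sub-directory per class under `train` and `val`. As a quick sanity check before training, the sketch below counts the images found in each class directory using only the standard library. The function name `summarize_split` and the list of file extensions are assumptions made for illustration; they are not part of this repository.

```python
import os

# Hypothetical helper (not part of this repository): counts image files per
# class directory for a split laid out as <root>/<split>/<class>/<image>.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")  # assumed set of extensions


def summarize_split(root, split):
    """Return {class_name: image_count} for the directory root/split."""
    split_dir = os.path.join(root, split)
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files next to the class directories
        counts[class_name] = sum(
            1
            for name in os.listdir(class_dir)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts


if __name__ == "__main__":
    root = "./data/ILSVRC2012"
    for split in ("train", "val"):
        if os.path.isdir(os.path.join(root, split)):
            print(split, summarize_split(root, split))
```

A class that reports zero images, or a `val` split whose class names differ from `train`, usually points to a misplaced directory.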
Note
The teacher checkpoint is downloaded automatically.
Run the following command to launch the experiment:
CUDA_VISIBLE_DEVICES=0 python3 ./imagenet_train.py
To run other experiments, modify the following losses in imagenet_config.py. The three configurations below correspond, in order, to CD, CD+GKD, and CD+GKD+EDT.
loss_list = [
{"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
{"loss_name": "CDLoss", "loss_rate": 6, "factor": 1, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
loss_list = [
{"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
{"loss_name": "KDLossv2", "T": 1, "loss_rate": 1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv1"},
{"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
loss_list = [
{"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv2"},
{"loss_name": "KDLossv2", "T": 1, "loss_rate": 1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv2"},
{"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv2"},
]
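The three configurations differ only in which losses are present and which decay setting is used: `CDLoss` is the CD loss, `KDLossv2` is the GKD loss, and the `lrdv2` decay setting marks EDT. The sketch below turns a `loss_list` into the matching method label, which can help when logging runs. The function `describe_experiment` and the fallback label `"baseline"` are hypothetical and not part of this repository.

```python
def describe_experiment(loss_list):
    """Map a loss_list to a method label such as 'CD+GKD+EDT'."""
    loss_names = {entry["loss_name"] for entry in loss_list}
    decays = {entry["loss_rate_decay"] for entry in loss_list}

    parts = []
    if "CDLoss" in loss_names:
        parts.append("CD")
    if "KDLossv2" in loss_names:
        parts.append("GKD")
    if "lrdv2" in decays:
        parts.append("EDT")
    # Assumed label for a configuration with cross-entropy only.
    return "+".join(parts) if parts else "baseline"


if __name__ == "__main__":
    loss_list = [
        {"loss_name": "CELoss", "loss_rate": 1, "factor": 1,
         "loss_type": "ce_family", "loss_rate_decay": "lrdv2"},
        {"loss_name": "KDLossv2", "T": 1, "loss_rate": 1, "factor": 1,
         "loss_type": "kdv2_family", "loss_rate_decay": "lrdv2"},
        {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9,
         "loss_type": "fd_family", "loss_rate_decay": "lrdv2"},
    ]
    print(describe_experiment(loss_list))
```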
Method | Model | Top-1 error (%) | Top-5 error (%) |
---|---|---|---|
teacher | ResNet34 | 26.73 | 8.74 |
student | ResNet18 | 30.43 | 10.76 |
KD | ResNet34-ResNet18 | 29.50 | 9.52 |
CD (ours) | ResNet34-ResNet18 | 28.53 | 9.56 |
CD+GKD (ours) | ResNet34-ResNet18 | 28.26 | 9.41 |
CD+GKD+EDT (ours) | ResNet34-ResNet18 | 27.61 | 9.2 |
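To make the ablation easier to read, the sketch below computes how many points of Top-1 error each configuration removes relative to the ResNet18 student baseline. The numbers are copied from the table above; the helper `gain_over_student` is a hypothetical name used only for this illustration.

```python
# Top-1 error (%) on ImageNet, copied from the table above.
top1_error = {
    "student": 30.43,
    "KD": 29.50,
    "CD": 28.53,
    "CD+GKD": 28.26,
    "CD+GKD+EDT": 27.61,
}


def gain_over_student(results):
    """Return the Top-1 error reduction (in points) of each method vs. the student."""
    base = results["student"]
    return {
        method: round(base - error, 2)
        for method, error in results.items()
        if method != "student"
    }


gains = gain_over_student(top1_error)
for method, gain in gains.items():
    print(f"{method}: {gain:.2f} points lower Top-1 error than the student")
```

The full configuration, CD+GKD+EDT, lowers Top-1 error by 2.82 points (30.43 to 27.61), compared with 0.93 points for plain KD.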
Method | Model | Top-1 error (%) | Top-5 error (%) |
---|---|---|---|
teacher | ResNet34 | 26.73 | 8.74 |
student | ResNet18 | 30.43 | 10.76 |
KD | ResNet34-ResNet18 | 29.50 | 9.52 |
FitNets | ResNet34-ResNet18 | 29.34 | 10.77 |
AT | ResNet34-ResNet18 | 29.30 | 10.00 |
RKD | ResNet34-ResNet18 | 28.46 | 9.74 |
CD+GKD+EDT (ours) | ResNet34-ResNet18 | 27.61 | 9.2 |
The CIFAR100 dataset is downloaded automatically.
Note
Download the teacher checkpoint from here, then put it in the pretrain directory.
Run the following command to launch the experiment:
CUDA_VISIBLE_DEVICES=0 python3 ./cifar_train.py
To run other experiments, modify the following losses in cifar_config.py. The three configurations below correspond, in order, to CD, CD+GKD, and CD+GKD+EDT.
loss_list = [
{"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
{"loss_name": "CDLoss", "loss_rate": 6, "factor": 1, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
loss_list = [
{"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv1"},
{"loss_name": "KDLossv2", "T": 1, "loss_rate": 0.1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv1"},
{"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv1"},
]
loss_list = [
{"loss_name": "CELoss", "loss_rate": 1, "factor": 1, "loss_type": "ce_family", "loss_rate_decay": "lrdv2"},
{"loss_name": "KDLossv2", "T": 1, "loss_rate": 0.1, "factor": 1, "loss_type": "kdv2_family", "loss_rate_decay": "lrdv2"},
{"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9, "loss_type": "fd_family", "loss_rate_decay": "lrdv2"},
]
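Each entry carries a `loss_rate` and a `factor`, which read naturally as a per-loss weight and a decay multiplier. The sketch below shows one plausible way such fields could combine into a single training loss. It is an illustration under stated assumptions, not the repository's implementation: the step-wise decay rule, the `step` interval of 30 epochs, and the function names are all hypothetical, and the `loss_rate_decay` variants (`lrdv1`, `lrdv2`) are not modeled. The actual schedule is defined in `utils/util.py`.

```python
def decayed_rate(loss_rate, factor, epoch, step=30):
    """ASSUMED rule: multiply the weight by `factor` once every `step` epochs."""
    return loss_rate * factor ** (epoch // step)


def total_loss(loss_values, loss_list, epoch, step=30):
    """Weighted sum of per-loss scalars, keyed by each entry's loss_name."""
    total = 0.0
    for entry in loss_list:
        weight = decayed_rate(entry["loss_rate"], entry["factor"], epoch, step)
        total += weight * loss_values[entry["loss_name"]]
    return total


if __name__ == "__main__":
    loss_list = [
        {"loss_name": "CELoss", "loss_rate": 1, "factor": 1},
        {"loss_name": "KDLossv2", "loss_rate": 0.1, "factor": 1},
        {"loss_name": "CDLoss", "loss_rate": 6, "factor": 0.9},
    ]
    # Made-up scalar loss values, used only to exercise the function.
    loss_values = {"CELoss": 2.0, "KDLossv2": 1.0, "CDLoss": 0.5}
    for epoch in (0, 30, 60):
        print(epoch, total_loss(loss_values, loss_list, epoch))
```

Under this assumed rule, only the CD term shrinks over time in the configuration above, because it is the only entry with a `factor` below 1.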
Method | Model | Top-1 error (%) | Top-5 error (%) |
---|---|---|---|
teacher | ResNet152 | 19.09 | 4.45 |
student | ResNet50 | 22.02 | 5.74 |
KD | ResNet152-ResNet50 | 20.36 | 4.94 |
CD (ours) | ResNet152-ResNet50 | 20.08 | 4.78 |
CD+GKD (ours) | ResNet152-ResNet50 | 19.49 | 4.85 |
CD+GKD+EDT (ours) | ResNet152-ResNet50 | 18.63 | 4.29 |