fairseq_mmt

Multimodal machine translation (MMT)

Dependencies

  • PyTorch version == 1.9.1
  • Python version == 3.6.7
  • timm version == 0.4.12
  • vizseq version == 0.1.15
  • nltk version == 3.6.4
  • sacrebleu version == 1.5.1

Install fairseq

cd fairseq_mmt
pip install --editable ./

Multi30k data & Flickr30k entities

The Multi30k data is available here and here; the Flickr30k Entities data is available here. We take the Multi30k text data from Revisit-MMT.

cd fairseq_mmt
git clone https://github.com/BryanPlummer/flickr30k_entities.git
cd flickr30k_entities
unzip annotations.zip

# download the data and arrange it in a directory (anywhere) with the following layout
flickr30k
 flickr30k-images
 test2017-images
 test_2016_flickr.txt
 test_2017_flickr.txt
 test_2017_mscoco.txt
 test_2018_flickr.txt
 testcoco-images
 train.txt
 val.txt

Extract image features

1. Vision Transformer

# please read scripts/README.md and modify the timm code first!
python3 scripts/get_img_feat.py --dataset train --model vit_base_patch16_384 --path ../flickr30k

script parameters:

  • dataset: choices=['train', 'val', 'test2016', 'test2017', 'testcoco']
  • model: e.g. 'vit_base_patch16_384'; any model name listed by timm.list_models() works
  • path: '/path/to/your/flickr30k'
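
For orientation, here is a minimal sketch (not the repo's exact script) of per-image feature extraction with timm; the image filename is an illustrative assumption. Note that stock timm returns only the pooled [CLS] vector from forward_features for ViT, which is exactly why scripts/README.md asks you to modify timm to expose all patch tokens.

# Minimal sketch of ViT feature extraction with timm; image path is illustrative.
import timm
import torch
from PIL import Image
from torchvision import transforms

model = timm.create_model('vit_base_patch16_384', pretrained=True)
model.eval()

preprocess = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

img = preprocess(Image.open('../flickr30k/flickr30k-images/example.jpg').convert('RGB'))
with torch.no_grad():
    # stock timm returns only the pooled [CLS] vector here; the timm
    # modification described in scripts/README.md exposes all patch tokens
    feat = model.forward_features(img.unsqueeze(0))
print(feat.shape)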

2. DETR

# first, run scripts/get_img_feat_detr.py to download the official DETR code and model
# then modify detr.py (in the official DETR code) to return the image feature, as shown in the figure above
python3 scripts/get_img_feat_detr.py --dataset train --path ../flickr30k

script parameters:

  • dataset: choices=['train', 'val', 'test2016', 'test2017', 'testcoco']
  • path: '/path/to/your/flickr30k'
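
As a rough illustration of where this feature comes from, the official DETR model can be loaded through torch.hub and its transformer-encoder output captured with a forward hook. The hook point and image path below are assumptions for illustration; the repo's script modifies detr.py directly instead.

# Hedged sketch: capture DETR's transformer-encoder output via a forward hook.
import torch
from PIL import Image
from torchvision import transforms

detr = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
detr.eval()

preprocess = transforms.Compose([
    transforms.Resize(800),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = preprocess(Image.open('../flickr30k/flickr30k-images/example.jpg').convert('RGB'))

feats = {}
detr.transformer.encoder.register_forward_hook(
    lambda module, inputs, output: feats.update(encoder=output))
with torch.no_grad():
    detr(img.unsqueeze(0))
print(feats['encoder'].shape)  # (num_tokens, batch, 256) for detr_resnet50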

Create masking data

pip3 install stanfordcorenlp 
wget https://nlp.stanford.edu/software/stanford-corenlp-latest.zip
unzip stanford-corenlp-latest.zip

cd fairseq_mmt
cat data/multi30k/train.en data/multi30k/valid.en data/multi30k/test.2016.en > train_val_test2016.en
python3 get_and_record_noun_from_f30k_entities.py # record each noun and its position in every sentence using flickr30k_entities
python3 record_color_people_position.py
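
To make the intent of these recording scripts concrete, here is a small, hypothetical example of how stanfordcorenlp tags a sentence so that noun positions can be recorded; the CoreNLP directory name depends on the version you unzipped.

# Illustrative only: locate nouns and their token positions with CoreNLP.
from stanfordcorenlp import StanfordCoreNLP

nlp = StanfordCoreNLP('./stanford-corenlp-4.5.4')  # adjust to your unzipped dir
sentence = 'Two young men are playing frisbee in the park'
tags = nlp.pos_tag(sentence)  # [(token, POS), ...]
nouns = [(i, tok) for i, (tok, tag) in enumerate(tags) if tag.startswith('NN')]
print(nouns)  # [(2, 'men'), (5, 'frisbee'), (8, 'park')]
nlp.close()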

cd data/masking
# create en-de masking data
python3 match_origin2bpe_position.py en-de
python3 create_masking_multi30k.py en-de         # create mask1-4 & color & people data 
# create en-fr masking data
python3 match_origin2bpe_position.py en-fr
python3 create_masking_multi30k.py en-fr         # create mask1-4 & color & people data 
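
The origin-to-BPE matching step exists because a noun position recorded on the original tokens shifts once BPE splits words into subwords. The function below is a hedged sketch of that mapping, not the repo's actual implementation, assuming the usual '@@' continuation marker from subword-nmt.

# Sketch: map each original-token index to the indices of its BPE pieces.
def origin_to_bpe_map(bpe_tokens):
    mapping, pieces, orig_idx = {}, [], 0
    for i, piece in enumerate(bpe_tokens):
        pieces.append(i)
        if not piece.endswith('@@'):  # last piece of the original token
            mapping[orig_idx] = pieces
            pieces, orig_idx = [], orig_idx + 1
    return mapping

print(origin_to_bpe_map('a snow@@ bo@@ arder jumps'.split()))
# {0: [0], 1: [1, 2, 3], 2: [4]}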

sh preprocess_mmt.sh

Train and Test

1. Preprocess (mask1 as an example)

src='en'
tgt='de'
mask=mask1  # choices: mask1, mask2, mask3, mask4, maskc (color), maskp (people)
TEXT=data/multi30k-en-$tgt.$mask

fairseq-preprocess --source-lang $src --target-lang $tgt \
  --trainpref $TEXT/train \
  --validpref $TEXT/valid \
  --testpref $TEXT/test.2016,$TEXT/test.2017,$TEXT/test.coco \
  --destdir data-bin/multi30k.en-$tgt.$mask \
  --workers 8 --joined-dictionary \
  --srcdict data/dict.en2de_$mask.txt

Run sh preprocess.sh to generate the unmasked (mask0) data.

2. Train (mask1 as an example)

mask_data=mask1
data_dir=multi30k.en-de.mask1
src_lang='en'
tgt_lang='de'
image_feat=vit_base_patch16_384
tag=$image_feat/$image_feat-$mask_data
save_dir=checkpoints/multi30k-en2de/$tag
image_feat_path=data/$image_feat
image_feat_dim=768

criterion=label_smoothed_cross_entropy
fp16=1
lr=0.005
warmup=2000
max_tokens=4096
update_freq=1
keep_last_epochs=10
patience=10
max_update=8000
dropout=0.3

arch=image_multimodal_transformer_SA_top
SA_attention_dropout=0.1
SA_image_dropout=0.1
SA_text_dropout=0

CUDA_VISIBLE_DEVICES=0,1 fairseq-train data-bin/$data_dir \
  --save-dir $save_dir \
  --distributed-world-size 2 -s $src_lang -t $tgt_lang \
  --arch $arch \
  --dropout $dropout \
  --criterion $criterion --label-smoothing 0.1 \
  --task image_mmt --image-feat-path $image_feat_path --image-feat-dim $image_feat_dim \
  --optimizer adam --adam-betas '(0.9, 0.98)' \
  --lr $lr --min-lr 1e-09 --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates $warmup \
  --max-tokens $max_tokens --update-freq $update_freq --max-update $max_update \
  --find-unused-parameters \
  --share-all-embeddings \
  --patience $patience \
  --keep-last-epochs $keep_last_epochs \
  --SA-image-dropout $SA_image_dropout \
  --SA-attention-dropout $SA_attention_dropout \
  --SA-text-dropout $SA_text_dropout

You can run train_mmt.sh instead of the script above.

3. Test (mask1 as an example)

#sh translation_mmt.sh $1 $2
sh translation_mmt.sh mask1 vit_base_patch16_384  # the original, unmasked text is mask0

script parameters:

  • $1: choices=['mask1', 'mask2', 'mask3', 'mask4', 'maskc', 'maskp', 'mask0']
  • $2: the image-feature model, e.g. 'vit_base_patch16_384' (any name listed by timm.list_models())

Visualization

# uncomment lines 429-431 and 487-488 in fairseq/models/image_multimodal_transformer_SA.py
# decode again; the attention tensors will be written to the checkpoint dir
# prepare the files needed by visualization/visualization.py
cd visualization
python3 visualization.py
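
If you just want a quick look at the dumped tensors without the provided script, a heatmap along these lines works; the tensor filename and its (source tokens x image patches) layout are assumptions.

# Hedged sketch: plot one dumped attention tensor as a heatmap.
import torch
import matplotlib.pyplot as plt

attn = torch.load('path/to/checkpoint_dir/attn_example.pt')  # hypothetical filename
plt.imshow(attn.squeeze().cpu().numpy(), cmap='viridis', aspect='auto')
plt.xlabel('image patches')
plt.ylabel('source tokens')
plt.colorbar()
plt.savefig('attention_heatmap.png')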
Related Projects