Official code for the paper "Context-Aware Language Modeling for Goal-Oriented Dialogue Systems"
MIT License
`conda create --name CALM python=3.9.7`
`conda activate CALM`
`pip install -r requirements.txt`
`conda install pytorch==1.9.0 cudatoolkit=11.3 -c pytorch -c conda-forge`
`export PYTHONPATH="$PWD/offline_airdialogue"`
The `outputs/` folder contains checkpoints for our main model, our task-pretrained model, and our customer bot. (Note: all training runs use wandb by default; you can turn off wandb syncing in the config.)
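As an alternative to editing the config, wandb itself can be silenced through the `WANDB_MODE` environment variable (this is standard wandb behavior, not specific to this repo); for example:

```shell
# WANDB_MODE=offline keeps logs local; WANDB_MODE=disabled turns wandb off entirely.
WANDB_MODE=disabled python train_basic_agent.py
```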
`cd scripts/train`
To run data-parallel multi-GPU training, replace `python <script_path>` in any of the commands below with `python -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env <script_path>`.
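For example, applying that substitution to the standard LM training script on 4 GPUs (the GPU count here is purely illustrative) would look like:

```shell
# Data-parallel training: one process per GPU, 4 GPUs in this example
python -m torch.distributed.launch --nproc_per_node 4 --use_env train_basic_agent.py
```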
Pretraining CALM (two variants of the auxiliary loss function):
- script: `python train_pretrain_table_agent.py`
  config: `config/train_pretrain_table_agent.yaml`
- script: `python train_pretrain_simplified_aux_gpt2.py`
  config: `config/train_pretrain_simplified_aux_gpt2.yaml`
Training the customer bot:
- script: `python train_customer.py`
  config: `config/train_customer_bot.yaml`
Training CALM (two variants of the auxiliary loss function):
- script: `python train_real_table_agent.py`
  config: `config/train_real_table_agent.yaml`
- script: `python train_simplified_aux_gpt2.py`
  config: `config/train_simplified_aux_agent.yaml`
Training the standard LM:
- script: `python train_basic_agent.py`
  config: `config/train_basic_agent.yaml`
Training the reward model for Model-Based Rollout Planning:
- script: `python train_constraint_parser.py`
  config: `config/train_constraint_parser.yaml`
`cd scripts/eval`
Simulated Evaluation:
- script: `python selfplay_eval.py`
  config: `config/selfplay_eval.yaml`

Selfplay outputs are saved to the file specified by `selfplay/outputs_file` in the config. To print the success rate for a selfplay run: `python compute_results.py --results_file <your_eval_outputs_file>`
To control which GPUs are used, prefix any command with `CUDA_VISIBLE_DEVICES=<comma_separated_list_of_gpu_indices>`.
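For example, to restrict the selfplay evaluation to the first two GPUs (the indices here are chosen for illustration):

```shell
# Only GPUs 0 and 1 are visible to the process
CUDA_VISIBLE_DEVICES=0,1 python selfplay_eval.py
```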
Language Quality Evaluation:
- script: `python language_quality_eval.py`
  config: `config/language_eval.yaml`

To run this evaluation on multiple GPUs: `python -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env language_quality_eval.py`