Recaption large (Web)Datasets with vllm and save the artifacts. This is NOT a library; instead, it provides reference points that you're free to use and modify.
[!NOTE] I use the code in this repository for my own projects and I don't claim it to be out of this world. If you want to contribute an enhancement, you're more than welcome to open a PR. I'd greatly appreciate it.
Install the requirements: pip install -r requirements.txt. Then run:
python main.py \
--data_path="https://huggingface.co/datasets/pixparse/cc3m-wds/resolve/main/cc3m-train-0000.tar"
This will recaption a single shard of the CC3M dataset and serialize the artifacts inside a directory called sample_outputs. This directory will contain the original images, their original captions, and the predicted captions.
If you want to use multiple shards, run:
# full CC3M training set
python main.py \
--data_path="pipe:curl -s -f -L https://huggingface.co/datasets/pixparse/cc3m-wds/resolve/main/cc3m-train-{0000..0575}.tar"
You can enable watermark detection by passing --detect_watermarks. Note that this will require the following:
- the onnx and onnxruntime dependencies
- pip install git+https://github.com/sayakpaul/watermark-detection, then follow the steps to obtain the ONNX model needed for watermark detection

By default, the script will use all the available GPUs. Refer to the main.py script for a full list of the supported CLI arguments.
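To give a feel for the CPU-side ONNX path used for watermark detection, here is a hedged sketch with onnxruntime; the model filename, input resolution, and preprocessing are assumptions for illustration, not necessarily what the watermark-detection package prescribes:

```python
# Hedged sketch of CPU-only watermark scoring with onnxruntime.
# Model path, input size, and normalization below are assumptions.
import numpy as np
import onnxruntime as ort
from PIL import Image

session = ort.InferenceSession(
    "watermark_model.onnx", providers=["CPUExecutionProvider"]
)
input_name = session.get_inputs()[0].name

def watermark_scores(images):
    # Assumed preprocessing: RGB, 256x256, scaled to [0, 1], NCHW layout.
    batch = np.stack(
        [
            np.asarray(im.convert("RGB").resize((256, 256)), dtype=np.float32) / 255.0
            for im in images
        ]
    ).transpose(0, 3, 1, 2)
    return session.run(None, {input_name: batch})[0]
```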
I tested the above commands on two A100s and on eight H100s.
Recaptioning large image datasets has become a de-facto standard in the image generation community. So, I wanted a simple-yet-performant utility that would let me recaption large image datasets like CC3M. This is why I chose vllm, as it provides optimized inference across multiple GPUs off-the-shelf.
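A minimal sketch of the multi-GPU part: vllm shards the model across all visible GPUs via tensor parallelism. The model id here is a placeholder, not necessarily what this repo uses (see config.py for that):

```python
# Minimal sketch: construct a vllm engine spanning every available GPU.
import torch
from vllm import LLM

llm = LLM(
    model="<your-captioning-model>",                  # placeholder model id
    tensor_parallel_size=torch.cuda.device_count(),   # use all visible GPUs
)
```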
webdataset is a common format used by practitioners to conduct training on large-scale datasets. So, I chose that as an entrypoint. Specifically, I assume that your image-caption pair dataset is already sharded into multiple webdataset archives. Refer here for an example.
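For reference, a minimal sketch of consuming one such shard; the "jpg"/"txt" key names match CC3M-style shards and may differ for your data:

```python
# Minimal sketch: iterate over (image, caption) pairs from a webdataset shard.
import webdataset as wds

url = "https://huggingface.co/datasets/pixparse/cc3m-wds/resolve/main/cc3m-train-0000.tar"
dataset = (
    wds.WebDataset(url)
    .decode("pil")                # decode images with PIL
    .to_tuple("jpg;png", "txt")   # (image, caption) pairs
)

for image, caption in dataset:
    print(image.size, caption[:60])
    break
```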
I need to be able to use multiple GPUs, overlapping communication and computation. But this project also works with a single GPU.
There has to be artifact serialization. This project serializes the original image, original caption, and predicted caption in separate threads without blocking the GPU(s) (see the sketch after this list).
There has to be watermark detection in the data curation pipeline at a minimum; otherwise, it hurts the generation quality. In this project, it happens during dataloading. To avoid clogging the processes, we make use of ONNX for fast CPU-based inference.
Failures can happen during the captioning process, so we need to be able to avoid duplication. I have added a simple ExistsFilter to skip images that were already serialized before an interruption (a sketch of this idea appears at the end of this section).
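A hedged sketch of the non-blocking serialization mentioned above; the function and file names are illustrative, not this repo's actual API:

```python
# Hedged sketch: hand writes to a small thread pool so disk I/O overlaps with
# GPU work on the next batch. Names and file layout are illustrative only.
import json
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

from PIL import Image

output_dir = Path("sample_outputs")
output_dir.mkdir(exist_ok=True)
pool = ThreadPoolExecutor(max_workers=4)

def _write_one(key: str, image: Image.Image, original: str, predicted: str) -> None:
    image.save(output_dir / f"{key}.jpg")
    (output_dir / f"{key}.json").write_text(
        json.dumps({"original_caption": original, "predicted_caption": predicted})
    )

def serialize_async(batch):
    """Submit writes and return immediately; `batch` yields (key, image, orig, pred)."""
    return [pool.submit(_write_one, *sample) for sample in batch]
```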
Ultimately, you'd want to modify the codebase to suit your needs.
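And a hedged sketch of the duplication-avoidance idea; the actual ExistsFilter lives in data_processing.py, and the key and extension assumptions below are mine:

```python
# Hedged sketch of an exists-filter: collect keys already written to the output
# directory and drop matching samples on restart. The ".jpg" extension and
# "__key__" field are assumptions for illustration.
from pathlib import Path

class ExistsFilter:
    def __init__(self, output_dir: str = "sample_outputs"):
        self.existing = {p.stem for p in Path(output_dir).glob("*.jpg")}

    def __call__(self, sample: dict) -> bool:
        # Keep the sample only if it was not serialized before the interruption.
        return sample["__key__"] not in self.existing
```

Something like this can be wired into the dataloading pipeline with webdataset's select() stage, which takes exactly this kind of predicate.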
Code organization:
.
├── config.py -- specifies the prompt used to generate the captions and the model id.
├── data_processing.py -- webdataset loading and processing code including watermark detection and caching.
├── main.py -- main entrypoint.
├── model.py -- loads the vllm engine and houses the simple inference function.
└── utils.py -- misc utilities.
If you have anything to modify, feel free to go into these files and modify them as per your needs.
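As an illustration only, config.py holds values along these lines; the prompt and model id shown here are placeholders, not the repo's actual choices:

```python
# config.py (hypothetical contents -- the real prompt and model id differ).
MODEL_ID = "<vision-language-model-id>"   # placeholder
PROMPT = "Describe the image in detail."  # placeholder captioning prompt
```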
Additional performance notes:
- PIL: simply replace your Pillow installation with Pillow-SIMD for better speed.
- hf_transfer: install it for faster downloads from the Hugging Face Hub. Refer here to know more.

Would really appreciate some contributions too :-)
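On the hf_transfer note above: it is opt-in, so besides pip install hf_transfer, the relevant environment variable has to be set before huggingface_hub is imported (or exported in your shell). A minimal way to do that in Python:

```python
# hf_transfer is opt-in; set this before importing huggingface_hub so that
# downloads go through the faster Rust-based backend.
import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
```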
Known limitations:
- ... vllm. But this restricts higher throughputs a bit.
- ... vllm.

Acknowledgements:
- vllm for the amazing project.
- webdataset for scalability.