An advanced Automatic Speech Recognition (ASR) system for Chinese (Traditional Mandarin) and Taiwanese Hokkien, built on OpenAI's Whisper model. The repository includes tools and scripts for data preprocessing, model training, and evaluation, and supports full fine-tuning, Parameter-Efficient Fine-Tuning (PEFT), and streaming inference, optimized for T4 GPUs.
ChineseTaiwaneseWhisper/
├── scripts/
│   ├── gradio_interface.py
│   ├── infer.py
│   └── train.py
├── src/
│   ├── config/
│   ├── crawler/
│   ├── data/
│   ├── models/
│   ├── trainers/
│   └── inference/
├── tests/
├── requirements.txt
├── setup.py
└── README.md
Clone the repository:
git clone https://github.com/sandy1990418/ChineseTaiwaneseWhisper.git
cd ChineseTaiwaneseWhisper
Set up a virtual environment:
python -m venv venv
source venv/bin/activate # On Windows, use `venv\Scripts\activate`
Install dependencies:
pip install -r requirements.txt
```bash
python scripts/train.py --model_name_or_path "openai/whisper-small" \
    --language "chinese" \
    --dataset_name "mozilla-foundation/common_voice_11_0" \
    --youtube_data_dir "./youtube_data" \
    --output_dir "./whisper-finetuned-zh-tw" \
    --num_train_epochs 3 \
    --per_device_train_batch_size 16 \
    --learning_rate 1e-5 \
    --fp16 \
    --use_timestamps False
```
```bash
python scripts/train.py --model_name_or_path "openai/whisper-small" \
    --language "chinese" \
    --use_peft \
    --peft_method "lora" \
    --dataset "common_voice_13_train","youtube_data" \
    --output_dir "Checkpoint_Path" \
    --num_train_epochs 10 \
    --per_device_train_batch_size 4 \
    --learning_rate 1e-5 \
    --fp16 \
    --use_timestamps True
```
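LoRA keeps the base model frozen and trains only small low-rank adapter matrices, which is why the PEFT command above fits comfortably on a T4. A back-of-the-envelope sketch of the savings (the 768 × 768 shape matches Whisper-small's attention projections; the exact total depends on which modules the adapters target):

```python
def lora_trainable_params(d: int, k: int, r: int) -> int:
    """Trainable parameters added by a rank-r LoRA adapter on a d x k weight."""
    # LoRA factorizes the update as (d x r) @ (r x k), so r*(d + k) parameters.
    return r * (d + k)

# One Whisper-small attention projection is 768 x 768.
full = 768 * 768                              # parameters updated by full fine-tuning
lora = lora_trainable_params(768, 768, r=8)   # rank-8 adapter on the same layer
print(full, lora, f"{lora / full:.1%}")       # 589824 12288 2.1%
```

Per layer, the adapter trains roughly 2% of the parameters, which is what makes the smaller `--per_device_train_batch_size` of 4 sufficient in the PEFT configuration.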
| Argument | Description | Default |
|---|---|---|
| `--model_name_or_path` | Path or name of the pre-trained model | Required |
| `--language` | Language for fine-tuning (e.g., "chinese", "taiwanese") | Required |
| `--dataset_name` | Name of the dataset to use | Required |
| `--dataset_config_names` | Configuration name for the dataset | Required |
| `--youtube_data_dir` | Directory containing YouTube data | Optional |
| `--output_dir` | Directory to save the fine-tuned model | Required |
| `--num_train_epochs` | Number of training epochs | 3 |
| `--per_device_train_batch_size` | Batch size per GPU/CPU for training | 16 |
| `--learning_rate` | Initial learning rate | 3e-5 |
| `--fp16` | Use mixed precision training | False |
| `--use_timestamps` | Include timestamp information in training | False |
| `--use_peft` | Use Parameter-Efficient Fine-Tuning | False |
| `--peft_method` | PEFT method to use (e.g., "lora") | None |
Launch the interactive web interface:
python scripts/gradio_interface.py
Access the interface at http://127.0.0.1:7860 (the default URL).
Note: For streaming mode, use Chrome instead of Safari to avoid CPU memory issues.
```bash
python scripts/infer.py --model_path openai/whisper-small \
    --audio_files audio.wav \
    --mode batch \
    --use_timestamps False
```
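When `--use_timestamps` is enabled, Whisper encodes segment boundaries as `<|seconds|>` tokens in the decoded text (the same format described in the dataset-format section below). A minimal sketch of splitting such output into timed segments; `parse_segments` is an illustrative helper, not part of this project's API:

```python
import re

def parse_segments(text: str):
    """Split decoded text like '<|0.00|>hello<|6.00|>' into (start, end, text) tuples."""
    # Capturing group keeps the timestamp values in the split result,
    # so tokens alternates: ['', '0.00', 'hello', '6.00', ''].
    tokens = re.split(r"<\|(\d+\.\d+)\|>", text)
    segments = []
    for i in range(1, len(tokens) - 2, 2):
        start, chunk, end = tokens[i], tokens[i + 1], tokens[i + 2]
        if chunk.strip():  # skip empty spans between back-to-back timestamps
            segments.append((float(start), float(end), chunk.strip()))
    return segments

print(parse_segments("<|0.00|>hello world<|6.00|>"))
# [(0.0, 6.0, 'hello world')]
```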
| Argument | Description | Default |
|---|---|---|
| `--model_path` | Path to the fine-tuned model | Required |
| `--audio_files` | Path(s) to audio file(s) for transcription | Required |
| `--mode` | Inference mode ("batch" or "stream") | "batch" |
| `--use_timestamps` | Include timestamps in transcription | False |
| `--device` | Device to use for inference (e.g., "cuda", "cpu") | "cuda" if available, else "cpu" |
| `--output_dir` | Directory to save transcription results | "output" |
| `--use_peft` | Use PEFT model for inference | False |
| `--language` | Language of the audio (e.g., "chinese", "taiwanese") | "chinese" |
Collect YouTube data:
```bash
python src/crawler/youtube_crawler.py \
    --playlist_urls "YOUTUBE_PLAYLIST_URL" \
    --output_dir ./youtube_data \
    --dataset_name youtube_asr_dataset \
    --file_prefix language_prefix
```
| Argument | Description | Default |
|---|---|---|
| `--playlist_urls` | YouTube playlist URL(s) to crawl | Required |
| `--output_dir` | Directory to save audio files and dataset | "./output" |
| `--dataset_name` | Name of the output dataset file | "youtube_dataset" |
| `--file_prefix` | Prefix for audio and subtitle files | "youtube" |
To customize the project:
- Select or add datasets via the `dataset_name` parameter
- Choose the PEFT method via `peft_method` and adjust configurations in `src/config/train_config.py`
- Extend `ChineseTaiwaneseASRInference` in `src/inference/flexible_inference.py` for custom inference behavior
Run tests with pytest:
pytest tests/
For detailed output:
pytest -v tests/
Check test coverage:
pip install pytest-cov
pytest --cov=src tests/
On a T4 GPU, without any acceleration methods:
This baseline gives you an idea of the default performance. Depending on your specific needs, you may want to optimize further or use acceleration techniques.
To address memory issues or improve performance on T4 GPUs:
- Reduce the batch size (`--per_device_train_batch_size`)
- Accumulate gradients over several steps (`--gradient_accumulation_steps`)
- Enable mixed precision training (`--fp16`)

These options can be combined for further performance improvements.
Note: The actual performance may vary depending on your specific hardware, audio complexity, and chosen optimization techniques. Always benchmark your specific use case.
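Gradient accumulation trades wall-clock time for memory while keeping the effective batch size, and thus the gradient statistics, unchanged. For example, a per-device batch of 4 with 4 accumulation steps matches the batch of 16 used in the full fine-tuning command:

```python
def effective_batch_size(per_device: int, accum_steps: int, num_gpus: int = 1) -> int:
    """Effective global batch size under gradient accumulation."""
    return per_device * accum_steps * num_gpus

# Fits T4 memory, but optimizes as if the batch size were 16:
print(effective_batch_size(per_device=4, accum_steps=4))  # 16
```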
```mermaid
graph TD
    A[Start] --> B[Set Up System]
    B --> C{Listen for Audio}
    C -->|Audio Received| D[Check for Speech]
    D -->|Speech Found| E[Transcribe Audio]
    D -->|No Speech| F[Skip Transcription]
    E --> G[Output Result]
    F --> G
    G --> C
    C -->|No More Audio| H[Finish Up]
    H --> I[End]
```
```mermaid
graph TD
    A[Start] --> B[Initialize Audio Stream]
    B --> C[Initialize ASR Model]
    C --> D[Initialize VAD Model]
    D --> E[Initialize Audio Buffer]
    E --> F[Initialize ThreadPoolExecutor]
    F --> G{Receive Audio Chunk}
    G -->|Yes| H[Add to Audio Buffer]
    H --> I{Buffer Full?}
    I -->|No| G
    I -->|Yes| J[Submit Chunk to ThreadPool]
    J --> K[Apply VAD]
    K --> L{Speech Detected?}
    L -->|No| O[Slide Buffer]
    L -->|Yes| M[Process Audio Chunk]
    M --> N[Generate Partial Transcription]
    N --> O
    O --> G
    G -->|No| P[Process Remaining Audio]
    P --> Q[Finalize Transcription]
    Q --> R[End]
    subgraph "Parallel Processing"
        J
        K
        L
        M
        N
    end
```
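The buffer-and-gate loop in the diagram above can be sketched with a simple energy threshold standing in for the real VAD model (the chunk sizes and threshold are illustrative; the project's actual implementation lives in `src/inference/`):

```python
from collections import deque

CHUNK = 4          # samples per incoming chunk (tiny, for illustration)
BUFFER_CHUNKS = 2  # transcribe once this many chunks are buffered

def has_speech(samples, threshold=0.01):
    """Energy-based stand-in for a VAD model."""
    energy = sum(s * s for s in samples) / len(samples)
    return energy > threshold

def stream(chunks):
    buffer = deque()
    results = []
    for chunk in chunks:
        buffer.extend(chunk)                       # add to audio buffer
        if len(buffer) >= CHUNK * BUFFER_CHUNKS:   # buffer full?
            window = list(buffer)
            if has_speech(window):                 # apply VAD
                # a real system would run ASR here for a partial transcription
                results.append(f"transcribed {len(window)} samples")
            for _ in range(CHUNK):                 # slide buffer: drop oldest chunk
                buffer.popleft()
    return results

silence = [0.0] * CHUNK
speech = [0.5] * CHUNK
print(stream([silence, speech, speech]))  # silent windows are skipped
```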
The Chinese/Taiwanese Whisper ASR project uses a specific format for its datasets to ensure compatibility with the training and inference scripts. The format can include or exclude timestamps, depending on the configuration.
Each item in the dataset represents an audio file and its corresponding transcription:
```python
{
    "audio": {
        "path": "path/to/audio/file.wav",
        "sampling_rate": 16000
    },
    "sentence": "The transcription of the audio in Chinese or Taiwanese.",
    "language": "zh-TW",  # or "zh-CN" for Mandarin, "nan" for Taiwanese, etc.
    "duration": 10.5      # Duration of the audio in seconds
}
```
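A small validator for this schema can catch malformed items before training. This is an illustrative helper, not part of the repository; the accepted language codes and extensions are assumptions based on the example above:

```python
def validate_item(item: dict) -> list:
    """Return a list of problems found in one dataset item (empty = valid)."""
    errors = []
    audio = item.get("audio", {})
    if not audio.get("path", "").endswith((".wav", ".mp3", ".flac")):
        errors.append("audio.path missing or unsupported extension")
    if audio.get("sampling_rate") != 16000:
        errors.append("sampling_rate must be 16000 (Whisper's expected rate)")
    if not item.get("sentence"):
        errors.append("sentence is empty")
    if item.get("language") not in {"zh-TW", "zh-CN", "nan"}:
        errors.append("unexpected language code")
    if not isinstance(item.get("duration"), (int, float)) or item["duration"] <= 0:
        errors.append("duration must be a positive number")
    return errors

item = {
    "audio": {"path": "clips/sample.wav", "sampling_rate": 16000},
    "sentence": "你好",
    "language": "zh-TW",
    "duration": 10.5,
}
print(validate_item(item))  # []
```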
labels:
```
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>ๅฐๅ็ฎ<|endoftext|>
```
In this example:
- `<|startoftranscript|>`: Marks the beginning of the transcription
- `<|zh|>`: Indicates the language (Chinese)
- `<|transcribe|>`: Denotes that this is a transcription task
- `<|notimestamps|>`: Indicates that no timestamps are included
- `ๅฐๅ็ฎ`: The actual transcription
- `<|endoftext|>`: Marks the end of the transcription

labels:
```
<|startoftranscript|><|zh|><|transcribe|><|0.00|>่ๅฐๆจๅธๆไบคๆๅถไฝ็จๆๅคง็้่ณผ<|6.00|><|endoftext|>
```
In this example:
- `<|startoftranscript|>`, `<|zh|>`, and `<|transcribe|>`: Same as above
- `<|0.00|>`: Timestamp indicating the start of the transcription (0 seconds)
- `่ๅฐๆจๅธๆไบคๆๅถไฝ็จๆๅคง็้่ณผ`: The actual transcription
- `<|6.00|>`: Timestamp indicating the end of the transcription (6 seconds)
- `<|endoftext|>`: Marks the end of the transcription

Whether timestamps appear in the labels is controlled by the `use_timestamps` parameter in your training and inference scripts.

If you're preparing your own dataset:
- Start each label with the special tokens (`<|startoftranscript|>`, `<|zh|>`, etc.).
- If including timestamps, insert `<|seconds.decimals|>` before each segment of transcription.
- Use `<|notimestamps|>` if not including timestamp information.
- End each label with `<|endoftext|>`.

By following this format, you ensure that your dataset is compatible with the Chinese/Taiwanese Whisper ASR system, allowing for efficient training and accurate inference.
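The rules above can be sketched as a small label builder (an illustrative helper, not part of the repository; the token spellings follow the examples shown in this section):

```python
def build_label(segments, language="zh", timestamps=True):
    """Build a Whisper training label from (start, end, text) segments."""
    label = f"<|startoftranscript|><|{language}|><|transcribe|>"
    if timestamps:
        # One <|seconds.decimals|> token before and after each segment.
        for start, end, text in segments:
            label += f"<|{start:.2f}|>{text}<|{end:.2f}|>"
    else:
        label += "<|notimestamps|>" + "".join(text for _, _, text in segments)
    return label + "<|endoftext|>"

print(build_label([(0.0, 6.0, "hello")]))
# <|startoftranscript|><|zh|><|transcribe|><|0.00|>hello<|6.00|><|endoftext|>
```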
Development Mode:
fastapi dev api_main.py
Production Mode:
fastapi run api_main.py
The API will be accessible at http://0.0.0.0:8000 by default.
bash app/docker.sh
Access the Swagger UI documentation at http://localhost:8000/docs
when the server is running.
Health Check:
curl -k http://localhost:8000/health
Transcribe Audio:
curl -k -X POST -H "Content-Type: multipart/form-data" -F "file=@/path/to/your/audio/file.wav" http://localhost:8000/transcribe
Replace `/path/to/your/audio/file.wav` with the actual path to your audio file.
List All Transcriptions:
curl -k http://localhost:8000/transcriptions
Get a Specific Transcription:
curl -k http://localhost:8000/transcription/{transcription_id}
Replace `{transcription_id}` with the actual UUID of the transcription.
Delete a Transcription:
curl -k -X DELETE http://localhost:8000/transcription/{transcription_id}
Replace `{transcription_id}` with the UUID of the transcription you want to delete.
This project is licensed under the MIT License. See the LICENSE file for details.