metadata
library_name: transformers
language:
  - ms
  - en
base_model:
  - Qwen/Qwen2.5-1.5B
datasets:
  - malaysia-ai/Malaysian-STT

Streaming-STT-1.5B

Continued pretraining of Qwen/Qwen2.5-1.5B on malaysia-ai/Malaysian-STT, with native support for:

  1. Streaming mode, using the <|streaming|> prefix.
  2. Semantic VAD, by predicting the probability of the <|endofspeech|> token in streaming mode.
  3. Whole mode, using the <|whole|> prefix.
  4. Segment-level timestamps, using the <|segment|> prefix.
  5. Word-level timestamps, using the <|word|> prefix.
  6. Transcription of audio longer than 30 seconds.
  7. Plug and play with any continuous-batching serving framework such as vLLM, because it is just another Qwen2.5 model.
  8. The GLM4 speech tokenizer at 12.5 tokens per second (TPS). Discrete tokens work very well with prefix caching, especially for streaming.
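The prefixes above combine into a single prompt string. A minimal sketch, assuming the prompt layout used in the examples below (mode prefix, timestamp prefix, `<|s{id}|>` speech tokens, then `<|endofspeech|>`); `build_prompt` and `estimated_speech_tokens` are hypothetical helpers, not part of the released code:

```python
# Sketch only: build_prompt and estimated_speech_tokens are hypothetical helpers
# that mirror the prompt strings used in the examples in this README.
SPEECH_TOKENS_PER_SECOND = 12.5  # GLM4 speech tokenizer rate stated above


def build_prompt(speech_ids, streaming=False, timestamps='segment', end_of_speech=True):
    """Build the prompt for the first (or only) audio chunk."""
    if timestamps not in ('segment', 'word'):
        raise ValueError("timestamps must be 'segment' or 'word'")
    mode = '<|streaming|>' if streaming else '<|whole|>'
    speech = ''.join(f'<|s{t}|>' for t in speech_ids)
    prompt = mode + f'<|{timestamps}|>' + speech
    # <|endofspeech|> closes the audio; leave it off when you want the model
    # to score whether the speaker has finished (semantic VAD).
    if end_of_speech:
        prompt += '<|endofspeech|>'
    return prompt


def estimated_speech_tokens(seconds):
    """Approximate number of speech tokens for a clip of the given length."""
    return round(seconds * SPEECH_TOKENS_PER_SECOND)


print(build_prompt([101, 7], timestamps='word'))
# <|whole|><|word|><|s101|><|s7|><|endofspeech|>
print(estimated_speech_tokens(60))
# 750
```

This is arithmetic only, not a tested limit: at 12.5 TPS a 60-second clip is about 750 speech tokens, and the 10240-token training context corresponds to roughly 819 seconds of audio before counting prefix and transcript tokens.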

Training is still in progress.

How we train

  1. Multipacking with proper document masking at a context length of 10240 tokens.
  2. FP32-BF16 mixed-precision training.
  3. Full-parameter finetuning.
  4. WandB logs at https://wandb.ai/huseinzol05/Qwen-Qwen2.5-1.5B-STT-10k
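Multipacking concatenates several training samples into one 10240-token row; document masking stops tokens from attending to other samples packed in the same row. A minimal pure-Python sketch of the idea, assuming the common approach of restarting position ids per document and using a block-diagonal causal mask. `cu_seqlens` is the cumulative-length format accepted by variable-length attention kernels such as FlashAttention's `flash_attn_varlen_func`. The actual training code lives in the source repository and may implement this differently:

```python
# Generic illustration, not the training code from the source repository.
def pack_metadata(doc_lengths):
    """Position ids restart at 0 for every packed document; cu_seqlens marks
    document boundaries as cumulative offsets into the packed row."""
    position_ids, cu_seqlens = [], [0]
    for n in doc_lengths:
        position_ids.extend(range(n))
        cu_seqlens.append(cu_seqlens[-1] + n)
    return position_ids, cu_seqlens


def document_causal_mask(doc_lengths):
    """mask[i][j] is True when token i may attend to token j: causal, and
    only within the same document."""
    doc_of = [d for d, n in enumerate(doc_lengths) for _ in range(n)]
    size = len(doc_of)
    return [[j <= i and doc_of[i] == doc_of[j] for j in range(size)]
            for i in range(size)]


positions, cu = pack_metadata([3, 2])
print(positions, cu)  # [0, 1, 2, 0, 1] [0, 3, 5]
for row in document_causal_mask([3, 2]):
    print(''.join('x' if allowed else '.' for allowed in row))
# x....
# xx...
# xxx..
# ...x.
# ...xx
```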

How to

First, install the speech tokenizer:

pip3 install git+https://github.com/malaysia-ai/glm4-audio-tokenizer

Then load the model:

from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from glm4_audio_tokenizer import Glm4Tokenizer
import torch

glm4 = Glm4Tokenizer().to(torch.float16).cuda()
model = AutoModelForCausalLM.from_pretrained('malaysia-ai/Streaming-STT-1.5B').cuda()
tokenizer = AutoTokenizer.from_pretrained('malaysia-ai/Streaming-STT-1.5B')
streamer = TextStreamer(tokenizer)

Whole segment timestamp mode

# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat.mp3
speech_tokens = glm4.tokenize(['husein-chat.mp3'])
token = ''.join([f'<|s{t}|>' for t in speech_tokens[0]]) + '<|endofspeech|>'
prompt = '<|whole|><|segment|>' + token
generate_kwargs = dict(
    **tokenizer(prompt, return_tensors = 'pt').to('cuda'),
    max_new_tokens=1024,
    top_p=0.95,
    top_k=50,
    temperature=0.1,
    do_sample=True,
    repetition_penalty=1.0,
    streamer=streamer
)
generation_output = model.generate(**generate_kwargs)

Output:

<|0.30|> Hai,<|0.56|><|1.14|> saya adalah pembantu<|2.14|><|2.48|> AI anda.<|2.96|><|3.56|> Selamat berkenalan!<|4.44|><|5.00|> Apa yang saya boleh tolong<|6.16|><|6.48|> untuk buatkan hari anda lebih ceria?<|8.58|><|endoftext|>
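The timestamped output alternates start time, text and end time. A small regex-based sketch for turning it into (start, end, text) tuples, assuming the `<|seconds|>` format shown in these examples; `parse_timestamps` is a hypothetical helper, not part of transformers or this repository. It also works on word-level output, where each tuple holds a single word:

```python
import re

# <|start|> text <|end|>; timestamps always contain a decimal point, so speech
# tokens such as <|s123|> and <|endoftext|> are never captured as times.
TIMESTAMPED = re.compile(r'<\|(\d+\.\d+)\|>(.*?)<\|(\d+\.\d+)\|>', re.S)


def parse_timestamps(text):
    """Return (start_seconds, end_seconds, text) tuples from model output."""
    return [(float(start), float(end), body.strip())
            for start, body, end in TIMESTAMPED.findall(text)]


output = ('<|0.30|> Hai,<|0.56|><|1.14|> saya adalah pembantu<|2.14|>'
          '<|2.48|> AI anda.<|2.96|><|endoftext|>')
for start, end, text in parse_timestamps(output):
    print(f'{start:.2f}-{end:.2f} {text}')
# 0.30-0.56 Hai,
# 1.14-2.14 saya adalah pembantu
# 2.48-2.96 AI anda.
```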

Whole word timestamp mode

# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat.mp3
speech_tokens = glm4.tokenize(['husein-chat.mp3'])
token = ''.join([f'<|s{t}|>' for t in speech_tokens[0]]) + '<|endofspeech|>'
prompt = '<|whole|><|word|>' + token
generate_kwargs = dict(
    **tokenizer(prompt, return_tensors = 'pt').to('cuda'),
    max_new_tokens=1024,
    top_p=0.95,
    top_k=50,
    temperature=0.1,
    do_sample=True,
    repetition_penalty=1.0,
    streamer=streamer
)
generation_output = model.generate(**generate_kwargs)

Output:

<|0.30|> Hai,<|0.56|><|1.14|> saya<|1.36|><|1.48|> adalah<|1.76|><|1.82|> pembantu<|2.20|><|2.38|> AI<|2.66|><|2.82|> anda.<|3.04|><|3.64|> Selamat<|3.94|><|4.00|> berkenalan!<|4.50|><|5.06|> Apa<|5.20|><|5.28|> yang<|5.40|><|5.46|> saya<|5.60|><|5.66|> boleh<|5.82|><|5.86|> tolong<|6.18|><|6.50|> untuk<|6.70|><|6.76|> buatkan<|7.08|><|7.16|> hari<|7.36|><|7.50|> anda<|7.66|><|7.80|> lebih<|7.98|><|8.04|> ceria?<|8.56|><|endoftext|>
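Word-level timestamps also expose the pauses between words. As a hypothetical post-processing sketch, consecutive words can be merged into phrases wherever the gap between them is short. `group_words_by_pause` and `max_gap=0.3` are illustrative choices, not model settings, and this heuristic is not how the model's own `<|segment|>` mode decides boundaries:

```python
import re

WORD = re.compile(r'<\|(\d+\.\d+)\|>(.*?)<\|(\d+\.\d+)\|>')


def group_words_by_pause(text, max_gap=0.3):
    """Merge consecutive words into phrases; start a new phrase when the silence
    between one word's end and the next word's start exceeds max_gap seconds."""
    groups = []
    for start, word, end in WORD.findall(text):
        start, end, word = float(start), float(end), word.strip()
        if groups and start - groups[-1][1] <= max_gap:
            first_start, _, phrase = groups[-1]
            groups[-1] = (first_start, end, f'{phrase} {word}')
        else:
            groups.append((start, end, word))
    return groups


words = ('<|0.30|> Hai,<|0.56|><|1.14|> saya<|1.36|><|1.48|> adalah<|1.76|>'
         '<|1.82|> pembantu<|2.20|><|2.38|> AI<|2.66|><|2.82|> anda.<|3.04|>'
         '<|3.64|> Selamat<|3.94|><|4.00|> berkenalan!<|4.50|><|endoftext|>')
for group in group_words_by_pause(words):
    print(group)
# (0.3, 0.56, 'Hai,')
# (1.14, 3.04, 'saya adalah pembantu AI anda.')
# (3.64, 4.5, 'Selamat berkenalan!')
```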

Streaming segment timestamp mode

# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part1.mp3
# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part2.mp3
# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part3.mp3

speech_tokens = glm4.tokenize(['husein-chat-part1.mp3', 'husein-chat-part2.mp3', 'husein-chat-part3.mp3'])

prompt = '<|streaming|><|segment|>'
for i in range(len(speech_tokens)):
    token = ''.join([f'<|s{t}|>' for t in speech_tokens[i]]) + '<|endofspeech|>'
    
    input_ids = tokenizer(prompt + token, return_tensors = 'pt').to('cuda')
    generate_kwargs = dict(
        **input_ids,
        max_new_tokens=1024,
        top_p=0.95,
        top_k=50,
        temperature=0.1,
        do_sample=True,
        repetition_penalty=1.0,
    )
    generation_output = model.generate(**generate_kwargs)
    new_prompt = tokenizer.decode(generation_output[0])
    prompt = new_prompt
    print(f'index {i + 1}: {prompt}')
    print()

Output:

index 1: <|0.02|> Hai. Saya ada laporan bantuan IIN dah.<|3.26|><|endoftext|>

index 2: <|3.70|> Dah lama berkenalan. Apa yang saya boleh tolong?<|6.94|><|endoftext|>

index 3: <|7.36|> Untuk buatkan hari anda lebih ceria.<|9.56|><|endoftext|>
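In the streaming loop, each request is the previous request plus its generated text plus the next chunk's speech tokens, so every request extends the one before it. That append-only prompt is what makes prefix caching effective for streaming. A model-free sketch of the same bookkeeping, where `stream` and `fake_generate` are hypothetical names and `fake_generate` stands in for `model.generate` plus decoding:

```python
def stream(chunks, generate, prefix='<|streaming|><|segment|>'):
    """Yield (request, new_text) per audio chunk, growing one shared history.

    `generate` is a hypothetical stand-in for model.generate + tokenizer.decode
    that returns only the newly generated text for a request.
    """
    history = prefix
    for speech_ids in chunks:
        request = history + ''.join(f'<|s{t}|>' for t in speech_ids) + '<|endofspeech|>'
        new_text = generate(request)
        yield request, new_text
        # The next request starts with this whole string, so a server with
        # prefix caching can reuse the KV cache already computed for it.
        history = request + new_text


def fake_generate(request):
    # Toy generator that just reports how many chunks it has seen.
    n = request.count('<|endofspeech|>')
    return f'<|{n}.00|> chunk {n}<|{n}.50|><|endoftext|>'


requests = [r for r, _ in stream([[1, 2], [3], [4]], fake_generate)]
print(all(later.startswith(earlier) for earlier, later in zip(requests, requests[1:])))
# True
```

With Hugging Face `generate` on a decoder-only model, the returned ids include the prompt, so only the newly generated part can be decoded with `generation_output[0][input_ids['input_ids'].shape[1]:]`.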

Streaming word timestamp mode

# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part1.mp3
# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part2.mp3
# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part3.mp3

speech_tokens = glm4.tokenize(['husein-chat-part1.mp3', 'husein-chat-part2.mp3', 'husein-chat-part3.mp3'])

prompt = '<|streaming|><|word|>'
for i in range(len(speech_tokens)):
    token = ''.join([f'<|s{t}|>' for t in speech_tokens[i]]) + '<|endofspeech|>'
    
    input_ids = tokenizer(prompt + token, return_tensors = 'pt').to('cuda')
    generate_kwargs = dict(
        **input_ids,
        max_new_tokens=1024,
        top_p=0.95,
        top_k=50,
        temperature=0.1,
        do_sample=True,
        repetition_penalty=1.0,
    )
    generation_output = model.generate(**generate_kwargs)
    new_prompt = tokenizer.decode(generation_output[0])
    prompt = new_prompt
    print(f'index {i + 1}: {prompt}')
    print()

Output:

index 1: <|0.02|> Hai.<|0.36|><|0.40|> Saya<|1.14|><|1.34|> ada<|1.46|><|1.54|> laporan<|1.90|><|1.96|> tu<|2.02|><|2.20|> AIA<|2.54|><|2.68|> anda.<|3.08|><|endoftext|>

index 2: <|3.60|> Selamat<|4.04|><|4.10|> berkenalan.<|4.62|><|4.66|> Apa<|4.72|><|4.76|> yang<|4.82|><|4.86|> saya<|4.92|><|4.96|> boleh<|5.06|><|5.10|> tolong?<|5.44|><|5.48|> Apa<|5.52|><|5.56|> yang<|5.62|><|5.66|> saya<|5.72|><|5.76|> boleh<|5.84|><|5.88|> tolong?<|6.00|><|6.04|> Apa<|6.08|><|6.12|> yang<|6.18|><|6.22|> saya<|6.28|><|6.32|> boleh<|6.40|><|6.44|> tolong?<|6.56|><|6.60|> Apa<|6.64|><|6.68|> yang<|6.74|><|6.78|> saya<|6.84|><|6.88|> boleh<|6.96|><|7.00|> tolong?<|7.10|><|endoftext|>

index 3: <|7.54|> Untuk<|7.80|><|7.88|> buatkan<|8.22|><|8.30|> hari<|8.50|><|8.62|> anda<|8.80|><|8.92|> lebih<|9.10|><|9.14|> ceria.<|9.42|><|endoftext|>
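Given per-chunk outputs like the ones listed above, a plain-text transcript only needs the `<|...|>` tokens stripped. A minimal sketch; `plain_transcript` is a hypothetical helper, and it does not remove repeated phrases such as the loop in the index 2 output:

```python
import re

# Any <|...|> token: timestamps, speech tokens, <|endoftext|>, mode prefixes.
SPECIAL = re.compile(r'<\|[^|]*\|>')


def plain_transcript(chunk_outputs):
    """Join per-chunk outputs into one transcript without special tokens."""
    text = ' '.join(SPECIAL.sub(' ', out) for out in chunk_outputs)
    # Collapse the extra whitespace left behind by removed tokens.
    return ' '.join(text.split())


chunks = [
    '<|0.02|> Hai.<|0.36|><|0.40|> Saya<|1.14|><|1.34|> ada<|1.46|><|1.54|> laporan'
    '<|1.90|><|1.96|> tu<|2.02|><|2.20|> AIA<|2.54|><|2.68|> anda.<|3.08|><|endoftext|>',
    '<|7.54|> Untuk<|7.80|><|7.88|> buatkan<|8.22|><|8.30|> hari<|8.50|><|8.62|> anda'
    '<|8.80|><|8.92|> lebih<|9.10|><|9.14|> ceria.<|9.42|><|endoftext|>',
]
print(plain_transcript(chunks))
# Hai. Saya ada laporan tu AIA anda. Untuk buatkan hari anda lebih ceria.
```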

Semantic VAD

# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-not-proper-cut.mp3
# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-proper-cut.mp3

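# husein-chat-part3.mp3 is downloaded in the streaming examples above.
# No download link is given for dummy-record.mp3; replace it with your own recording.
speech_tokens = glm4.tokenize(['husein-chat-not-proper-cut.mp3', 'dummy-record.mp3', 'husein-chat-proper-cut.mp3', 'husein-chat-part3.mp3'])
for i in range(len(speech_tokens)):
    prompt = '<|streaming|><|word|>'
    # No <|endofspeech|> appended: we want the model to predict it.
    token = ''.join([f'<|s{t}|>' for t in speech_tokens[i]])
    input_ids = tokenizer(prompt + token, return_tensors = 'pt').to('cuda')
    with torch.no_grad():
        logits = model(**input_ids).logits
    print(i, logits[0, -1, 151665]) # raw logit of <|endofspeech|> (token id 151665)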

Output:

0 tensor(96.5629, device='cuda:0') # not proper cut
1 tensor(97.0512, device='cuda:0') # not proper cut
2 tensor(102.7403, device='cuda:0') # proper cut
3 tensor(100.4126, device='cuda:0') # proper cut
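The values printed above are raw logits for token id 151665, not probabilities; a logit alone ignores how highly the competing tokens score. To get a probability, normalise the full vocabulary row with a softmax; in PyTorch that is `torch.softmax(logits[0, -1].float(), dim=-1)[151665]`. The same computation in plain Python, with a hypothetical `threshold=0.5` end-of-turn rule that would need tuning on your own data:

```python
import math


def token_probability(logits, token_id):
    """Softmax probability of token_id from one row of next-token logits,
    computed with the log-sum-exp trick for numerical stability."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return math.exp(logits[token_id] - log_norm)


def speech_ended(logits, token_id, threshold=0.5):
    """Hypothetical end-of-turn rule: threshold the <|endofspeech|> probability."""
    return token_probability(logits, token_id) >= threshold


row = [2.0, 0.0, 0.0]  # toy 3-token vocabulary
print(round(token_probability(row, 0), 4))  # 0.787
print(speech_ended(row, 0))  # True
```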

Source code

Source code is available at https://github.com/malaysia-ai/cooking/tree/main/qwen-stt

Acknowledgement

Special thanks to the Lambda Research Grant program for the Lambda cloud credits!