Model Overview

⚠️ Whisper is currently only available via the keras-hub-nightly package. Use pip install keras-hub-nightly to try this model.

A Whisper encoder-decoder network for speech.

This class implements a Transformer-based encoder-decoder model as described in "Robust Speech Recognition via Large-Scale Weak Supervision". It includes the embedding lookups and transformer layers, but not the head for predicting the next token.
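The backbone returns hidden states, not vocabulary scores. The sketch below shows, in plain Python, roughly what a next-token head adds on top. It assumes a head that reuses (ties) the decoder's token embedding matrix, as OpenAI's original Whisper decoder does. This head is not part of WhisperBackbone, and the function names are hypothetical.

```python
def next_token_logits(hidden_state, token_embeddings):
    """Score each vocabulary entry against one decoder hidden state.

    hidden_state: hidden_dim numbers for a single decoder position.
    token_embeddings: vocabulary_size rows of hidden_dim numbers
    (assumed to be the decoder's own token embedding matrix, i.e. tied weights).
    """
    # One dot product per vocabulary row gives one logit per token.
    return [sum(h * e for h, e in zip(hidden_state, row)) for row in token_embeddings]


def greedy_next_token(hidden_state, token_embeddings):
    """Return the index of the highest-scoring token (greedy decoding)."""
    logits = next_token_logits(hidden_state, token_embeddings)
    return max(range(len(logits)), key=logits.__getitem__)


# Toy example: a vocabulary of 3 tokens and a hidden_dim of 2.
embeddings = [[1, 0], [0, 1], [1, 1]]
print(next_token_logits([2, 9], embeddings))  # [2, 9, 11]
print(greedy_next_token([2, 9], embeddings))  # 2
```

A real head would apply this projection at every decoder position. In practice it would use a framework matrix multiply rather than Python loops.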

The default constructor gives a fully customizable, randomly initialized Whisper model with any number of layers, heads, and embedding dimensions. To load preset architectures and weights, use the from_preset() constructor.

Disclaimer: Pre-trained models are provided on an "as is" basis, without warranties or conditions of any kind. The underlying model is provided by a third party and subject to a separate license, available here.

Arguments

  • vocabulary_size: int. The size of the token vocabulary.
  • num_layers: int. The number of transformer layers in each of the encoder and the decoder.
  • num_heads: int. The number of attention heads in each transformer layer. hidden_dim must be divisible by num_heads.
  • hidden_dim: int. The size of the transformer hidden states (the model dimension of the encoder and decoder).
  • intermediate_dim: int. The output dimension of the first Dense layer in the two-layer feedforward network of each transformer layer.
  • num_mels: int. The number of mel-frequency filters. Defaults to 80.
  • dropout: float. Dropout probability for the Transformer encoder.
  • max_encoder_sequence_length: int. The maximum sequence length that the audio encoder can consume. Since the second convolutional layer in the encoder reduces the sequence length by half (stride of 2), we use max_encoder_sequence_length // 2 as the sequence length for the positional embedding layer.
  • max_decoder_sequence_length: int. The maximum sequence length that the text decoder can consume.
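Several of these arguments constrain each other. The sketch below is a hypothetical helper, not part of KerasHub. It checks that hidden_dim divides evenly across the heads and computes the positional embedding length (max_encoder_sequence_length // 2). It also computes how many encoder positions a given number of mel frames produces. The convolution settings (kernel 3, stride 2, padding 1) are an assumption taken from OpenAI's Whisper encoder.

```python
def whisper_encoder_shapes(hidden_dim, num_heads, max_encoder_sequence_length, num_frames):
    """Hypothetical helper: derive shapes implied by the backbone arguments."""
    # Each attention head gets an equal slice of the hidden size.
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim={hidden_dim} must be divisible by num_heads={num_heads}"
        )
    head_dim = hidden_dim // num_heads

    # The stride-2 convolution halves the frame count, so the positional
    # embedding only needs max_encoder_sequence_length // 2 positions.
    encoder_positions = max_encoder_sequence_length // 2

    # Assumed conv settings: kernel 3, stride 2, padding 1.
    kernel, stride, padding = 3, 2, 1
    conv_frames = (num_frames + 2 * padding - kernel) // stride + 1
    if conv_frames > encoder_positions:
        raise ValueError(
            f"{num_frames} frames give {conv_frames} positions, "
            f"more than the {encoder_positions} available"
        )
    return {
        "head_dim": head_dim,
        "encoder_positions": encoder_positions,
        "conv_frames": conv_frames,
    }


# The custom config from the example below: 12 input frames become 6 encoder positions.
print(whisper_encoder_shapes(256, 4, 128, 12))
# {'head_dim': 64, 'encoder_positions': 64, 'conv_frames': 6}
```

Whisper's standard input is a 30-second window of 3000 mel frames (16 kHz audio with a 160-sample hop). With max_encoder_sequence_length=3000, this window maps to 1500 encoder positions.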

Example Usage

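import keras_hub
import numpy as np

input_data = {
    # Log-mel spectrogram features: (batch, num_frames, num_mels).
    "encoder_features": np.ones(shape=(1, 12, 80), dtype="float32"),
    # Decoder token IDs: (batch, sequence_length).
    "decoder_token_ids": np.ones(shape=(1, 12), dtype="int32"),
    # 1 marks a real token, 0 marks padding (the last two positions here).
    "decoder_padding_mask": np.array(
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32"
    ),
}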

# Randomly initialized Whisper encoder-decoder model with a custom config.
model = keras_hub.models.WhisperBackbone(
    vocabulary_size=51864,
    num_layers=4,
    num_heads=4,
    hidden_dim=256,
    intermediate_dim=512,
    max_encoder_sequence_length=128,
    max_decoder_sequence_length=128,
)
model(input_data)
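Real decoder inputs have different lengths, so they must be padded to a common length with a matching decoder_padding_mask. The sketch below is a hypothetical helper that builds both as nested lists. The pad_id=0 default is a placeholder assumption; use the pad token ID of your tokenizer instead.

```python
def pad_batch(token_id_seqs, max_length, pad_id=0):
    """Right-pad token ID sequences and build the matching padding mask.

    pad_id=0 is a placeholder assumption; use your tokenizer's pad token ID.
    """
    token_ids, padding_mask = [], []
    for seq in token_id_seqs:
        if len(seq) > max_length:
            raise ValueError(
                f"sequence of length {len(seq)} exceeds max_length={max_length}"
            )
        num_pad = max_length - len(seq)
        token_ids.append(list(seq) + [pad_id] * num_pad)
        # 1 marks a real token the decoder should attend to, 0 marks padding.
        padding_mask.append([1] * len(seq) + [0] * num_pad)
    return token_ids, padding_mask


ids, mask = pad_batch([[5, 6, 7], [8]], max_length=4)
print(ids)   # [[5, 6, 7, 0], [8, 0, 0, 0]]
print(mask)  # [[1, 1, 1, 0], [1, 0, 0, 0]]
```

Wrap both results in np.array(..., dtype="int32") before passing them as decoder_token_ids and decoder_padding_mask. A single 10-token sequence padded to length 12 reproduces the mask used in the example above.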

Example Usage with Hugging Face URI

import keras_hub
import numpy as np

# Load the pretrained whisper_medium_en backbone from the Hugging Face Hub.
model = keras_hub.models.WhisperBackbone.from_preset("hf://keras/whisper_medium_en")

input_data = {
    # Log-mel spectrogram features: (batch, num_frames, num_mels), with 80 mel bins.
    "encoder_features": np.ones(shape=(1, 12, 80), dtype="float32"),
    # Decoder token IDs: (batch, sequence_length).
    "decoder_token_ids": np.ones(shape=(1, 12), dtype="int32"),
    # 1 marks a real token, 0 marks padding (the last two positions here).
    "decoder_padding_mask": np.array(
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32"
    ),
}
model(input_data)