BigVAE is an AdaVAE trained as a pair of LoRA finetunes on Mistral 7B. It is meant to be used with the MiniHF VAE inference code and will not work if you try to load it as an ordinary language model checkpoint and run inference. AdaVAE is an encoder-decoder model trained by taking an existing GPT-N, designating one LoRA as the encoder and another as the decoder, and then tuning them with a latent attention mechanism.

This model is the encoder and router decoder head for BigVAE, a planned Mixture-of-Experts system based on LoRA retrieval rather than gating. It is usable on its own for embedding and retrieval, as well as for planning and guided sampling.

Here is an example of a sampling procedure for BigVAE which distills its autoregressive pretraining task into its autoassociative reconstruction task by averaging together multiple completions. It takes the topic sentence of a paragraph (prompt) and guides the next sentences by weighting them towards the topic, while averaging together multiple completions of each sentence to improve generation quality:
    import torch

    # NOTE: `tokenizer` and `device` are assumed to be globals defined by the
    # surrounding MiniHF inference script; they are not created here.
    def bigvae_generate_avg(vae_model, router, prompt, context, n_steps, n_avg):
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            # Tokenize the running context and the topic sentence (prompt).
            context_toks = tokenizer(context, return_tensors="pt")
            context_ids = context_toks["input_ids"].to(device)
            context_mask = context_toks["attention_mask"].to(device)
            embed_toks = tokenizer(prompt, return_tensors="pt")
            embed_ids = embed_toks["input_ids"].to(device)
            embed_mask = embed_toks["attention_mask"].to(device)
            # Latent of the topic sentence; every step is pulled towards it.
            mean = vae_model.encode(embed_ids, embed_mask)
            prompt_embed = vae_model.vae.sample(mean)
            for _ in range(n_steps):
                # Latent of the most recent text window.
                mean = vae_model.encode(embed_ids, embed_mask)
                z = vae_model.vae.sample(mean)
                embeds = []
                for _ in range(n_avg):
                    # Draft a completion guided half by z, half by the topic.
                    output_ids = router.generate(z * 0.5 + prompt_embed * 0.5,
                                                 context_ids,
                                                 context_mask,
                                                 256,
                                                 tau=0.9)
                    # Re-encode the last 128 tokens of the draft.
                    intermediate_embed_ids = output_ids[:,-128:]
                    intermediate_embed_mask = context_mask.new_ones(
                        [1, intermediate_embed_ids.shape[1]]
                    )
                    mean = vae_model.encode(intermediate_embed_ids, intermediate_embed_mask)
                    embeds.append(vae_model.vae.sample(mean))
                # Generate from the averaged draft latents, weighted 0.7 / 0.3
                # against the topic latent.
                output_ids = router.generate((sum(embeds) / n_avg * 0.7) + prompt_embed * 0.3,
                                             context_ids,
                                             context_mask,
                                             256,
                                             tau=0.9)
                # Commit the previous window to the context, then take the
                # 128 tokens preceding the final 128 as the next window.
                context_ids = torch.cat([context_ids, embed_ids], dim=1)
                context_mask = torch.cat([context_mask, embed_mask], dim=1)
                embed_ids = output_ids[:,-256:-128]
                embed_mask = context_mask.new_ones([1, embed_ids.shape[1]])
            out_texts = [tokenizer.decode(toks, skip_special_tokens=True) for toks in context_ids]
            return out_texts
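
The weighted sums in the procedure above (for example `(sum(embeds) / n_avg * 0.7) + prompt_embed * 0.3`) are ordinary averaging followed by a convex combination with the topic latent. The following is a minimal sketch of just that arithmetic on plain Python lists, so it can be run without a GPU or the MiniHF code. The helper names `average_latents` and `guide_toward_topic` are hypothetical and are not part of MiniHF; in real use the latents would be torch tensors.

```python
from typing import List, Sequence

Vector = List[float]


def average_latents(samples: Sequence[Vector]) -> Vector:
    """Element-wise mean of equally sized latent vectors."""
    if not samples:
        raise ValueError("need at least one latent sample")
    dim = len(samples[0])
    if any(len(s) != dim for s in samples):
        raise ValueError("all latents must have the same dimension")
    return [sum(s[i] for s in samples) / len(samples) for i in range(dim)]


def guide_toward_topic(latent: Vector, topic: Vector, topic_weight: float) -> Vector:
    """Convex combination: (1 - topic_weight) * latent + topic_weight * topic."""
    if not 0.0 <= topic_weight <= 1.0:
        raise ValueError("topic_weight must be in [0, 1]")
    if len(latent) != len(topic):
        raise ValueError("latent and topic must have the same dimension")
    return [(1.0 - topic_weight) * z + topic_weight * t
            for z, t in zip(latent, topic)]


# Toy example: two draft latents averaged, then pulled 30% towards the topic,
# mirroring the 0.7 / 0.3 weighting in bigvae_generate_avg.
drafts = [[1.0, 0.0], [3.0, 2.0]]
topic = [4.0, 3.0]
guided = guide_toward_topic(average_latents(drafts), topic, topic_weight=0.3)
```

A larger `topic_weight` keeps the generation closer to the topic sentence, while a smaller one lets the averaged drafts dominate.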
Here is an example of an output from this process:
Then it asked the network to reconstruct the input and the original embedding. The network had to learn to match the
embedding to the original input, therefore matching the inference by consuming the embedding. This was key because
the embedding had to be able to match the text with the text it was consumed with. 'Here's how you do it,' Boru told Mu,
'Just impute the mean and variance.' This Mu did, transforming not words but entire paragraphs into vectors and then
inferring the next paragraph. It took some tweaks and tuning to get the initial performance but the second arago spot
had been found. To make sure the network was learning the right thing, Boru had to check the first value in the vector.
If the first value was below 0, the network had failed to learn the first value. If the value was above 0, the network
had been able to learn the first value.
‘What have you called this, Boru?’ asked Mu. ‘Latent variable regression.’ ‘It looks like a mixture of density network
and autoencoder,’ said Nayaf. ‘It’s an autoencoder but it’s using latent variables, but we’re using the mean and variance
of Grade had a difficult time seeing it, but he could tell it was close. 'So you've found the second arago,' he said.
'Yes,' Rin replied. 'We just have to figure out how to use it.'
'How?' Rin asked.
'You can move the second word in, right?'
'Possibly.' Rin thought for a moment.
'The second word will be the first word of the next arago,' Mu said. 'We just need to find it.'
'True,' Rin agreed. 'Well, I'll let you know what a Gaussian.’ ‘Let’s see if we can get it to work.’ ‘Arago the second
spot?’ ‘We’re here,’ Arago said.
The second spot was located in the middle of the text. Arago had to read it again to find the proper signal. ‘I’m going
to have to tweak some of the weights,’ said Arago. ‘I’ve had to change the input to the next layer from an input to
output.’ ‘You’re making a mistake again,’ said Mu to Arago. ‘It’s a mistake.’ The network had been learning I find out.'
'That's the second arago,' Rin said.
'The second arago?' Argo asked.
'Rin has found the second arago.'
Argo stared at Rin. 'Argo, is there something wrong?'
'I thought so.'
'What?' Rin said.
'I don't know,' Argo said. 'I thought I was the smartest person in the world but, well, I only had a certain amount of
energy. I didn't know how to do the second arago until now, but I can't
This generation method is slow, but retrieval could be used to speed up inference and make it converge closer and closer to normal sampling speed as the model becomes able to call upon more and more relevant sentences that it has generated before.
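
The text above only says that retrieval could be used, so the following is one hypothetical form it might take rather than anything implemented in MiniHF: a cache that stores the latent of each previously generated sentence and returns the stored text when a new query latent is close enough by cosine similarity. The class name `LatentCache` and the `min_similarity` threshold are assumptions made for illustration.

```python
import math
from typing import List, Optional, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


class LatentCache:
    """Hypothetical store of (latent, text) pairs for reuse at sampling time."""

    def __init__(self) -> None:
        self._entries: List[Tuple[List[float], str]] = []

    def add(self, latent: List[float], text: str) -> None:
        self._entries.append((list(latent), text))

    def nearest(self, query: List[float], min_similarity: float = 0.9) -> Optional[str]:
        """Return the best-matching stored text, or None if nothing is close enough."""
        best_text, best_score = None, min_similarity
        for latent, text in self._entries:
            score = cosine_similarity(query, latent)
            if score >= best_score:
                best_text, best_score = text, score
        return best_text


cache = LatentCache()
cache.add([1.0, 0.0], "a sentence generated earlier")
cache.add([0.0, 1.0], "a different earlier sentence")
hit = cache.nearest([0.9, 0.1])   # close to the first stored latent
miss = cache.nearest([1.0, 1.0])  # not close enough to either entry
```

On a cache hit the sampler could skip a round of generation; on a miss it would fall back to the full averaging procedure and add the result to the cache. A linear scan is used here for clarity, and a real system would likely use an approximate nearest-neighbor index.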
Because the BigVAE combines guided sampling with the ability to merge representations, it becomes possible to formulate plans and cognitive strategies for the model to follow. The inference policy can adjudicate between an expected plan or series of steps and the specific context the model is responding to.
This model is also highly interpretable. Because it is an encoder-decoder, every sentence generated by the model has a latent representation that can be tracked alongside its behavioral token sequence. Our hope is that BigVAE will shed light on the latent operations performed by autoregressive language models and be useful to alignment and interpretability researchers.
Training procedure
This model was trained on a 1 billion token sample of RedPajama on 8x H100 GPUs for roughly 24 hours. The difference from v0.1 is that KL weight was turned up to 0.1 over 50k steps.
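
To make the KL weight schedule concrete, here is a minimal sketch of an annealed KL term. The source states only the final weight (0.1) and the number of steps (50k); the linear ramp starting from zero and the function names `kl_weight` and `vae_loss` are assumptions for illustration, and the actual schedule in the MiniHF training scripts may differ.

```python
def kl_weight(step: int, max_weight: float = 0.1, warmup_steps: int = 50_000) -> float:
    """KL coefficient at a given training step (assumed linear ramp from 0)."""
    if step < 0:
        raise ValueError("step must be non-negative")
    return max_weight * min(1.0, step / warmup_steps)


def vae_loss(reconstruction_loss: float, kl_divergence: float, step: int) -> float:
    """Total loss: reconstruction term plus the annealed KL penalty."""
    return reconstruction_loss + kl_weight(step) * kl_divergence


# Early in training the KL term contributes little; after 50k steps it is
# weighted at the full 0.1.
early = vae_loss(reconstruction_loss=2.0, kl_divergence=10.0, step=0)
late = vae_loss(reconstruction_loss=2.0, kl_divergence=10.0, step=50_000)
```

Ramping the KL weight up gradually is a common way to let the decoder learn to use the latent before the regularizer starts pulling the posterior towards the prior.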
Using the scripts in the MiniHF repo as they exist now, the training commands were:
accelerate launch train_vae_overlap.py --model "mistralai/Mistral-7B-v0.1" --preprocessed preprocessed_mistral --context 64 --output vae_64_overlap_mistral_2 --batch-size 24
accelerate launch train_vae_router.py --model "mistralai/Mistral-7B-v0.1" --preprocessed preprocessed_mistral --vae-context 64 --start-from vae_64_overlap_mistral_2 --output vae_64_overlap_router_mistral_2 --lr 1e-4 --batch-size 1
The following bitsandbytes quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: bfloat16
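
For readers who want to reproduce the 4-bit settings listed above, this is how they might be expressed with the `BitsAndBytesConfig` class from `transformers`. It is a sketch that loads only the quantized base model; it does not load the BigVAE adapters, which still require the MiniHF VAE inference code. The `llm_int8_*` fields in the list are left at their defaults, the helper name `load_quantized_base` is hypothetical, and `device_map="auto"` is an assumption that requires `accelerate` to be installed.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantization with double quantization and bfloat16 compute,
# matching the values listed in the config above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)


def load_quantized_base(model_name: str = "mistralai/Mistral-7B-v0.1"):
    """Load the base model in 4-bit (downloads weights; needs a CUDA GPU)."""
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )
```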
Framework versions
- PEFT 0.4.0