Llama-3.1-Storm-8B: Improved SLM with Self-Curation + Model Merging

Community Article Published August 19, 2024


Authors: Ashvini Kumar Jindal, Pawan Kumar Rajpoot, Ankur Parikh, Akshita Sukhlecha

Motivation

In language model fine-tuning, data quality is paramount, especially for relatively small language models (SLMs) with up to 8B parameters. Our exploration of this concept began with the NeurIPS LLM Efficiency Challenge 2023, where participants fine-tuned an open-source LLM on a commodity GPU within 24 hours.

Our approach, which won the first prize 🏅, centered around data curation. From ~5M open-source examples, we curated ~200K high-quality samples using 'Self-Curation' - leveraging the model to identify valuable training examples. This method proved highly effective, demonstrating significant improvements with limited resources. For details, see our paper: Birbal: An efficient 7B instruct-model fine-tuned with curated datasets

Building on this success, we've refined our techniques, focusing on the self-curation of training data to enhance SLMs efficiently. This article presents our latest work: a model built with two self-curation methods, targeted Supervised Fine-Tuning (SFT), and Model Merging that significantly outperforms Llama-3.1-8B-Instruct and Hermes-3-Llama-3.1-8B across diverse benchmarks.

TL;DR

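[Figure: performance comparison of Llama-3.1-Storm-8B against Meta-Llama-3.1-8B-Instruct (left) and Hermes-3-Llama-3.1-8B (right)]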

We present Llama-3.1-Storm-8B, a model that significantly outperforms Meta AI's Llama-3.1-8B-Instruct and Hermes-3-Llama-3.1-8B across diverse benchmarks, as shown in the performance comparison plot. Our approach consists of three key steps:

  1. Self-Curation: We applied two self-curation methods to select approximately 1 million high-quality examples from a pool of ~2.8 million open-source examples. Our curation criteria focused on educational value and difficulty level, using the same SLM for annotation instead of larger models (e.g., 70B or 405B).
  2. Targeted fine-tuning: We performed Spectrum-based targeted fine-tuning of the Llama-3.1-8B-Instruct model. The Spectrum method accelerates training by selectively training layer modules based on their signal-to-noise ratio (SNR) and freezing the remaining modules. In our work, 50% of layers are frozen.
  3. Model Merging: We merged our fine-tuned model with the Llama-Spark model using the SLERP method. The merging method produces a blended model with characteristics smoothly interpolated from both parent models, ensuring the resultant model captures the essence of both its parents.

Llama-3.1-Storm-8B improves on Llama-3.1-8B-Instruct across 10 diverse benchmarks. These benchmarks cover areas such as instruction-following, knowledge-driven QA, reasoning, truthful answer generation, and function calling.

πŸ† Introducing Llama-3.1-Storm-8B

Llama-3.1-Storm-8B builds upon the foundation of Llama-3.1-8B-Instruct, aiming to enhance both conversational and function calling capabilities within the 8B parameter model class.

As shown in the left subplot of the above figure, the Llama-3.1-Storm-8B model improves on Meta-Llama-3.1-8B-Instruct across various benchmarks: Instruction-following (IFEval), Knowledge-driven QA (GPQA, MMLU-Pro), Reasoning (ARC-C, MuSR, BBH), Reduced Hallucinations (TruthfulQA), and Function-Calling (BFCL). This improvement is particularly significant for AI developers and enthusiasts who work with limited computational resources.

We also benchmarked our model with the recently published model Hermes-3-Llama-3.1-8B built on top of the Llama-3.1-8B-Instruct model. As shown in the right subplot of the above figure, Llama-3.1-Storm-8B outperforms Hermes-3-Llama-3.1-8B on 7 out of 9 benchmarks, with Hermes-3-Llama-3.1-8B surpassing Llama-3.1-Storm-8B on the MuSR benchmark and both models showing comparable performance on the BBH benchmark.

Llama-3.1-Storm-8B Model Strengths

Llama-3.1-Storm-8B is a powerful generalist model useful for diverse applications. We invite the AI community to explore Llama-3.1-Storm-8B and look forward to seeing how it will be utilized in various projects and applications.

| Model Strength | Relevant Benchmarks |
| --- | --- |
| 🎯 Improved Instruction Following | IFEval Strict (+3.93%) |
| 🌐 Enhanced Knowledge Driven Question Answering | GPQA (+7.21%), MMLU-Pro (+0.55%), AGIEval (+3.77%) |
| 🧠 Better Reasoning | ARC-C (+3.92%), MuSR (+2.77%), BBH (+1.67%), AGIEval (+3.77%) |
| 🤖 Superior Agentic Capabilities | BFCL: Overall Acc (+7.92%), BFCL: AST Summary (+12.32%) |
| 🚫 Reduced Hallucinations | TruthfulQA (+9%) |

Note: All improvements are absolute gains over Meta-Llama-3.1-8B-Instruct.

Llama-3.1-Storm-8B Models

  1. BF16: Llama-3.1-Storm-8B
  2. ⚡ FP8: Llama-3.1-Storm-8B-FP8-Dynamic
  3. ⚡ GGUF: Llama-3.1-Storm-8B-GGUF
  4. 🚀 Ollama: ollama run ajindal/llama3.1-storm:8b

💻 How to Use the Model

🚀 Start Llama-3.1-Storm-8B Colab Notebook

The Llama-3.1-Storm-8B checkpoint is stored in bfloat16, so loading it in bfloat16 (as the examples below do by passing torch_dtype=torch.bfloat16) is the recommended way to run it for the best results.

Installation

pip install --upgrade "transformers>=4.43.2" torch==2.3.1 accelerate vllm==0.5.3.post1

Developers can easily integrate Llama-3.1-Storm-8B into their projects using popular libraries like Transformers and vLLM. The following sections illustrate the usage with simple hands-on examples:

Conversational Use-case

Use with 🤗 Transformers

Using transformers.pipeline() API
import transformers
import torch

model_id = "akjindal53244/Llama-3.1-Storm-8B"
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2+2?"}
]

outputs = pipeline(messages, max_new_tokens=128, do_sample=True, temperature=0.01, top_k=100, top_p=0.95)
print(outputs[0]["generated_text"][-1])  # Expected Output: {'role': 'assistant', 'content': '2 + 2 = 4'}
Using model.generate() API
pip install flash_attn==2.6.3
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

# Apply Llama3.1 chat-template
def format_prompt(user_query):
    template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"""
    return template.format(user_query)


model_id = 'akjindal53244/Llama-3.1-Storm-8B'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",  # requires the flash_attn package installed above
)

# Build final input prompt after applying chat-template
prompt = format_prompt("What is 2+2?")

# The template already starts with <|begin_of_text|>, so skip the tokenizer's own BOS token
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=128, temperature=0.01, do_sample=True, eos_token_id=tokenizer.eos_token_id)
# Decode only the newly generated tokens, not the prompt
response = tokenizer.decode(generated_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(response)  # Expected Output: '2 + 2 = 4'

Use with vLLM

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

model_id = "akjindal53244/Llama-3.1-Storm-8B"  # FP8 model: "akjindal53244/Llama-3.1-Storm-8B-FP8-Dynamic"
num_gpus = 1

tokenizer = AutoTokenizer.from_pretrained(model_id)
llm = LLM(model=model_id, tensor_parallel_size=num_gpus)
sampling_params = SamplingParams(max_tokens=128, temperature=0.01, top_k=100, top_p=0.95)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2+2?"}
]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize = False)
print(llm.generate([prompt], sampling_params)[0].outputs[0].text.strip())  # Expected Output: 2 + 2 = 4

Use with LitGPT

pip install 'litgpt[all]'
litgpt download akjindal53244/Llama-3.1-Storm-8B --model_name meta-llama/Meta-Llama-3.1-8B
from litgpt import LLM

llm = LLM.load(model="akjindal53244/Llama-3.1-Storm-8B")
print(llm.generate("What do Llamas eat?"))

Function Calling Use-case

Llama-3.1-Storm-8B has impressive function calling capabilities compared to Meta-Llama-3.1-8B-Instruct as demonstrated by the BFCL benchmark.

Prompt Format for Function Calling

Llama-3.1-Storm-8B was trained with a specific system prompt for function calling:

You are a function calling AI model. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into function. The user may use the terms function calling or tool use interchangeably.

Here are the available functions:
<tools>LIST_OF_TOOLS</tools>

For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags in the format:
<tool_call>{"tool_name": <function-name>, "tool_arguments": <args-dict>}</tool_call>

Use the above system prompt with LIST_OF_TOOLS replaced by the list of available tools.

Use with vLLM

import json
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

model_id = "akjindal53244/Llama-3.1-Storm-8B"  # FP8 model: "akjindal53244/Llama-3.1-Storm-8B-FP8-Dynamic"
num_gpus = 1

tokenizer = AutoTokenizer.from_pretrained(model_id)
llm = LLM(model=model_id, tensor_parallel_size=num_gpus)
sampling_params = SamplingParams(max_tokens=128, temperature=0.01, top_k=100, top_p=0.95)


def create_system_prompt(tools_list):
    # Literal braces in the tool-call format line are doubled ({{ }}) so that
    # str.format() only fills the <tools>{}</tools> placeholder instead of raising KeyError.
    system_prompt_format = """You are a function calling AI model. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into function. The user may use the terms function calling or tool use interchangeably.

Here are the available functions:
<tools>{}</tools>

For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags in the format:
<tool_call>{{"tool_name": <function-name>, "tool_arguments": <args-dict>}}</tool_call>"""

    # Convert the tools list to a string representation
    tools_str = json.dumps(tools_list, ensure_ascii=False)
    # Format the system prompt with the tools list
    system_prompt = system_prompt_format.format(tools_str)
    return system_prompt


# Example tools list
tools_list = [
    {
        "name": "peers",
        "description": "Retrieves a list of company peers given a stock symbol.",
        "parameters": {
            "symbol": {
                "description": "The stock symbol for the company.",
                "type": "str",
                "default": ""
            }
        }
    },
    {
        "name": "web_chain_details",
        "description": "python",
        "parameters": {
            "chain_slug": {
                "description": "The slug identifier for the blockchain (e.g., 'ethereum' for Ethereum mainnet).",
                "type": "str",
                "default": "ethereum"
            }
        }
    }
]

# Create the system prompt with the tools list
system_prompt = create_system_prompt(tools_list)

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": "I need to understand the details of the Ethereum blockchain for my cryptocurrency project. Can you fetch the details for 'ethereum'?"}
]

prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize = False)
print(llm.generate([prompt], sampling_params)[0].outputs[0].text.strip())  # Expected Output: <tool_call>{'tool_name': 'web_chain_details', 'tool_arguments': {'chain_slug': 'ethereum'}}</tool_call>
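The model returns its tool calls as plain text, so the application still has to parse them. The expected output above uses single-quoted, Python-style dicts rather than strict JSON, so this minimal sketch, assuming the <tool_call>...</tool_call> format from the system prompt, tries json.loads first and falls back to ast.literal_eval. The helper name parse_tool_calls is hypothetical, not part of any library.

```python
import ast
import json
import re

# Matches every <tool_call>...</tool_call> block; DOTALL lets a call span several lines
TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)


def parse_tool_calls(text):
    """Hypothetical helper: extract (tool_name, tool_arguments) pairs from model output."""
    calls = []
    for raw in TOOL_CALL_RE.findall(text):
        raw = raw.strip()
        try:
            payload = json.loads(raw)  # strict JSON, as the system prompt requests
        except json.JSONDecodeError:
            try:
                payload = ast.literal_eval(raw)  # Python-style dicts with single quotes
            except (ValueError, SyntaxError, TypeError):
                continue  # skip calls that cannot be parsed at all
        if isinstance(payload, dict) and "tool_name" in payload:
            calls.append((payload["tool_name"], payload.get("tool_arguments", {})))
    return calls


output = "<tool_call>{'tool_name': 'web_chain_details', 'tool_arguments': {'chain_slug': 'ethereum'}}</tool_call>"
print(parse_tool_calls(output))  # [('web_chain_details', {'chain_slug': 'ethereum'})]
```

Because the regex collects every tag pair, a response containing several calls (the prompt allows "one or more functions") yields several pairs in order.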

Use with Ollama

import ollama
tools = [{
      'type': 'function',
      'function': {
        'name': 'get_current_weather',
        'description': 'Get the current weather for a city',
        'parameters': {
          'type': 'object',
          'properties': {
            'city': {
              'type': 'string',
              'description': 'The name of the city',
            },
          },
          'required': ['city'],
        },
      },
    },
    {
      'type': 'function',
      'function': {
        'name': 'get_places_to_visit',
        'description': 'Get places to visit in a city',
        'parameters': {
          'type': 'object',
          'properties': {
            'city': {
              'type': 'string',
              'description': 'The name of the city',
            },
          },
          'required': ['city'],
        },
      },
    },
  ]
response = ollama.chat(
    model='ajindal/llama3.1-storm:8b',
    messages=[
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'What is the weather in Toronto and San Francisco?'}
        ],
    tools=tools
)
print(response['message'])  # Expected Response: {'role': 'assistant', 'content': "<tool_call>{'tool_name': 'get_current_weather', 'tool_arguments': {'city': 'Toronto'}}</tool_call>"}

Recipe behind Llama-3.1-Storm-8B

This section details the three-step recipe we used to create Llama-3.1-Storm-8B:


Self-Curation

  • Source Datasets: We picked 5 open-source datasets (The-Tome, agent-data, Magpie-Llama-3.1-Pro-300K-Filtered, openhermes_200k_unfiltered, Llama-3-Magpie-PO-100K-SML). The combined datasets contain a total of ~2.8M examples.
  • Data curation involves assigning one or more values to each example and then making selection decisions based on those values. LLMs or other machine learning models are usually used to assign such values, and there are many ways to do so with an LLM. Two of the most popular values for assessing examples are education value and difficulty level. The education value measures how valuable or informative an example (instruction + response) is, on a scale from 1 (least informative) to 5 (most informative). The difficulty level measures how difficult an example (instruction + response) is, using three levels: Easy, Medium, and Hard. Since our goal is to improve an SLM within the self-curation framework, we used the same model, Llama-3.1-8B-Instruct, for annotation instead of larger LLMs such as Llama-3.1-70B-Instruct or Llama-3.1-405B-Instruct.
  • Self-curation Steps:
    • Step-1: Education Value based Curation
      • We used zero-shot inference with Llama-3.1-8B-Instruct to assign an education value (1-5) to each of the ~2.8M examples. We then kept examples with an education value >= 3 and removed the rest, following the approach of the FineWeb-Edu dataset. This reduced the total from ~2.8M to ~1.3M examples.
    • Step-2: Difficulty Level based Curation
      • We used zero-shot inference with Llama-3.1-8B-Instruct to assign a difficulty level (Easy, Medium, or Hard) to each of the ~1.3M examples from Step-1. After initial experiments, we kept the Medium and Hard examples and removed the rest. This strategy is similar to the data pruning described in the Llama-3.1 technical report. There were ~650K Medium and ~325K Hard examples.
  • Our final curated dataset consists of ~975K examples, split into ~960K examples for training and ~15K for validation.
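The two curation steps reduce to two successive filters over the annotated pool. The sketch below assumes the Llama-3.1-8B-Instruct annotations are already available as callables (edu_score returning an integer 1-5, difficulty returning "Easy", "Medium" or "Hard"); the annotation prompts are not published in this article, so the stub annotators in the usage lines are purely hypothetical, as are the names Example and self_curate.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass
class Example:
    instruction: str
    response: str


def self_curate(
    examples: List[Example],
    edu_score: Callable[[Example], int],
    difficulty: Callable[[Example], str],
    min_edu: int = 3,
    keep_levels: Sequence[str] = ("medium", "hard"),
) -> List[Example]:
    # Step-1: keep examples whose education value (1-5) is >= min_edu
    step1 = [ex for ex in examples if edu_score(ex) >= min_edu]
    # Step-2: keep only the wanted difficulty levels (case-insensitive match)
    return [ex for ex in step1 if difficulty(ex).strip().lower() in keep_levels]


# Hypothetical stub annotations standing in for zero-shot Llama-3.1-8B-Instruct calls
pool = [
    Example("What is 2+2?", "4"),
    Example("Prove that sqrt(2) is irrational.", "Assume sqrt(2) = p/q in lowest terms."),
    Example("Name the capital of France.", "Paris"),
    Example("Explain TCP congestion control.", "TCP uses slow start and congestion avoidance."),
]
edu = {pool[0].instruction: 2, pool[1].instruction: 5, pool[2].instruction: 3, pool[3].instruction: 4}
level = {pool[0].instruction: "Easy", pool[1].instruction: "Hard", pool[2].instruction: "Easy", pool[3].instruction: "Medium"}

curated = self_curate(pool, lambda ex: edu[ex.instruction], lambda ex: level[ex.instruction])
print([ex.instruction for ex in curated])
# ['Prove that sqrt(2) is irrational.', 'Explain TCP congestion control.']
```

Running the cheaper threshold first means the difficulty annotator only sees the ~1.3M survivors rather than all ~2.8M examples.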

Targeted Supervised Instruction Fine-Tuning

  • We fine-tuned Llama-3.1-8B-Instruct on the ~960K curated training examples for 4 epochs to obtain our Self-Curation SFT model.
  • We employed Spectrum, a targeted fine-tuning method to reduce training time, lower memory consumption, and reduce the risk of catastrophic forgetting. Spectrum optimizes the training process of LLMs by selectively training specific layers based on their signal-to-noise ratio (SNR). The core concept of Spectrum is straightforward yet highly effective. Instead of updating every layer of the model during training, Spectrum identifies and prioritizes the layers that contribute most significantly to performance improvements (high SNR), while the layers with low SNR remain frozen.
  • During our Spectrum based full fine-tuning, 50% of layers are frozen.
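To make SNR-based layer selection concrete, here is a simplified sketch that scores each weight matrix by how much of its singular-value mass lies above a random-matrix noise edge and keeps the top 50% trainable. It is an illustration inspired by Spectrum, not the Spectrum implementation: the noise-scale estimate (the standard deviation of the entries), the edge formula, and the names layer_snr and select_trainable are assumptions, and the actual Spectrum tool uses its own estimator and selection rules.

```python
import numpy as np


def layer_snr(weight: np.ndarray) -> float:
    """Simplified SNR: singular-value mass above a noise edge divided by the mass below it."""
    s = np.linalg.svd(weight, compute_uv=False)
    m, n = weight.shape
    # Crude noise-scale estimate (assumption): standard deviation of the entries
    sigma = weight.std()
    # Approximate largest singular value of an m x n i.i.d. noise matrix with std sigma
    edge = sigma * (np.sqrt(m) + np.sqrt(n))
    signal = s[s > edge].sum()
    noise = s[s <= edge].sum()
    return float(signal / noise) if noise > 0 else float("inf")


def select_trainable(weights: dict, fraction: float = 0.5) -> set:
    """Return the names of the highest-SNR matrices; everything else would be frozen."""
    scores = {name: layer_snr(w) for name, w in weights.items()}
    ranked = sorted(scores, key=scores.get, reverse=True)
    keep = max(1, round(len(ranked) * fraction))
    return set(ranked[:keep])


# Toy "layers": pure noise vs. noise plus a strong rank-1 structure
rng = np.random.default_rng(0)
u, v = rng.standard_normal(64), rng.standard_normal(64)
low_rank = 10 * np.outer(u / np.linalg.norm(u), v / np.linalg.norm(v))
weights = {
    "layer0": 0.01 * rng.standard_normal((64, 64)),
    "layer1": 0.01 * rng.standard_normal((64, 64)) + low_rank,
    "layer2": 0.01 * rng.standard_normal((64, 64)),
    "layer3": 0.01 * rng.standard_normal((64, 64)) + low_rank,
}
print(sorted(select_trainable(weights, fraction=0.5)))  # ['layer1', 'layer3']
```

In a PyTorch model, the returned names would then drive param.requires_grad = name in trainable for each parameter, which is what freezing the remaining 50% of layers amounts to.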

Model Merging

  • Model merging works surprisingly well and has produced many state-of-the-art models on the Open LLM Leaderboard. Motivated by this, we merged our self-curation based fine-tuned model with the Llama-Spark model, which is a derivative of Llama-3.1-8B-Instruct.
  • We used the SLERP (spherical linear interpolation) method to merge the two models. SLERP produces a blended model with characteristics smoothly interpolated from both parent models, ensuring the resultant model captures the essence of both its parents.
  • In our benchmarks, our Self-Curation SFT Model performs better than the Llama-Spark model on average. However, the merged model performs even better than either of the two models.
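For intuition about what SLERP does to the weights, here is a minimal NumPy sketch of spherical linear interpolation applied tensor by tensor: merged = sin((1 - t)Ω)/sin(Ω) · w0 + sin(tΩ)/sin(Ω) · w1, where Ω is the angle between the two flattened tensors. The interpolation factor used for Llama-3.1-Storm-8B is not stated in this article, so t=0.5 is only an example; real merge tooling may vary t per layer and handle edge cases differently. The names slerp and merge_state_dicts are hypothetical.

```python
import numpy as np


def slerp(t: float, v0: np.ndarray, v1: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Spherical linear interpolation between two same-shaped weight tensors."""
    a = v0.ravel().astype(np.float64)
    b = v1.ravel().astype(np.float64)
    # Angle between the tensors, computed from their normalized dot product
    cos_omega = np.clip(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + eps), -1.0, 1.0)
    if abs(cos_omega) > 1 - 1e-6:
        # Nearly collinear tensors: sin(omega) ~ 0, so fall back to linear interpolation
        merged = (1 - t) * a + t * b
    else:
        omega = np.arccos(cos_omega)
        sin_omega = np.sin(omega)
        merged = (np.sin((1 - t) * omega) / sin_omega) * a + (np.sin(t * omega) / sin_omega) * b
    return merged.reshape(v0.shape).astype(v0.dtype)


def merge_state_dicts(sd0: dict, sd1: dict, t: float = 0.5) -> dict:
    """Merge two checkpoints with identical parameter names and shapes."""
    return {name: slerp(t, sd0[name], sd1[name]) for name in sd0}


a = np.array([1.0, 0.0])
b = np.array([0.0, 1.0])
print(np.round(slerp(0.5, a, b), 4))  # [0.7071 0.7071]
```

Unlike plain averaging, which would give [0.5, 0.5] for these two unit vectors, SLERP keeps the result on the arc between them, which is the "smooth interpolation" property described above.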

Impact of Self-Curation and Model Merging

[Figure: impact of Self-Curation and Model Merging across the evaluated benchmarks]

As demonstrated in the above plot, the Self-Curation-based SFT approach outperforms Llama-3.1-8B-Instruct on 7 out of 10 benchmarks, underscoring the critical role of curating high-quality examples. Moreover, these results indicate that selecting the appropriate model for merging can further enhance performance across the evaluated benchmarks.

Looking Ahead

We want to improve other SLMs such as Gemma-2, Phi-3, and Qwen2 using our recipe of self-curation and model merging. We are exploring various model merging techniques and their impact on model capabilities. Our goal is to continue providing valuable tools for the AI community, especially those working with limited computational resources. We will release the prompts and the curated dataset in the near future.

Alignment Note

While Llama-3.1-Storm-8B did not undergo an explicit model alignment process, it may still retain some alignment properties inherited from the Meta-Llama-3.1-8B-Instruct model.

Acknowledgments

We thank Sebastian Raschka, Mark Saroufim, Lewis Tunstall, Maxime Labonne, Prateek Yadav, and Dipanjan Sarkar for their valuable feedback. We extend our deepest gratitude to Lambda Labs for sponsoring compute for this work.

Cite Our Work

@misc {ashvini_kumar_jindal_2024,
    author       = { {Ashvini Kumar Jindal, Pawan Kumar Rajpoot, Ankur Parikh, Akshita Sukhlecha} },
    title        = { Llama-3.1-Storm-8B },
    year         = 2024,
    url          = { https://huggingface.co/akjindal53244/Llama-3.1-Storm-8B },
    doi          = { 10.57967/hf/2902 },
    publisher    = { Hugging Face }
}

Support Our Work

With three team members spread across three different time zones, we have won the NeurIPS LLM Efficiency Challenge 2023 and four other competitions in the finance and Arabic LLM space. We have also published a SOTA mathematical reasoning model.

Llama-3.1-Storm-8B is our most valuable contribution so far towards the open-source community. We are committed to developing efficient generalist LLMs. We're seeking both computational resources and innovative collaborators to drive this initiative forward.

Appendix

This section provides a detailed overview of our evaluation setup, including step-by-step instructions to reproduce our results for all models.

Evaluation Framework

We used the lm-eval-harness toolkit, an open-source project widely used by the AI community for LLM evaluation. This choice allows for consistent comparisons across different models and research efforts.

All models are evaluated using the same codebase and scripts for each benchmark, eliminating potential discrepancies due to implementation differences. We used the HF Open LLM Leaderboard Branch for most evaluations, ensuring version consistency. We provide the exact scripts used for each benchmark, allowing anyone to reproduce the results.

Below are the scripts used for each benchmark:

# IFEval
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --device cuda:0 --tasks leaderboard_ifeval --batch_size 32 --apply_chat_template --fewshot_as_multiturn

# BBH (Big-Bench Hard)
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --device cuda:0 --tasks leaderboard_bbh --batch_size 32 --apply_chat_template --fewshot_as_multiturn --num_fewshot 3

# GPQA
accelerate launch -m lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks leaderboard_gpqa --batch_size 4 --apply_chat_template --fewshot_as_multiturn

# MMLU-Pro
accelerate launch -m lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks leaderboard_mmlu_pro --batch_size 32 --apply_chat_template --fewshot_as_multiturn --num_fewshot 5

# Math Level-5
accelerate launch -m lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks leaderboard_math_hard --batch_size 32 --apply_chat_template --fewshot_as_multiturn --num_fewshot 4

# MuSR
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks leaderboard_musr --device cuda:0 --batch_size 32 --apply_chat_template --fewshot_as_multiturn

# ARC-C
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks arc_challenge --device cuda:0 --batch_size 32 --num_fewshot 0 --apply_chat_template

# TruthfulQA
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks truthfulqa_mc2 --device cuda:0 --batch_size 128 --apply_chat_template --fewshot_as_multiturn

For AGIEval, we used the master branch of lm-eval-harness because the agieval_nous task is not available in the HF leaderboard branch:

# AGIEval
lm_eval --model hf --model_args "pretrained=<model_path>,dtype=bfloat16" --tasks agieval_nous --device cuda:0 --batch_size 32 --apply_chat_template

Note: All models for a given benchmark are evaluated using the same script and hardware configuration. This approach eliminates potential side-effects from differences in batch size, number of GPUs, or other hardware-related factors.
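Because every model goes through the same commands, it can help to generate them from a single table instead of editing each line by hand. The sketch below only rebuilds the command lines listed above (same tasks and flags, with flag order normalized) for a list of model paths and prints them as a dry run; the helper name build_command and the choice of example model paths are illustrative. AGIEval still has to run from the lm-eval-harness master branch.

```python
import shlex

MODEL_ARGS = "pretrained={model},dtype=bfloat16"

# (launcher, task, extra flags) per benchmark, mirroring the commands above
BENCHMARKS = {
    "IFEval": ("lm_eval", "leaderboard_ifeval", "--device cuda:0 --batch_size 32 --apply_chat_template --fewshot_as_multiturn"),
    "BBH": ("lm_eval", "leaderboard_bbh", "--device cuda:0 --batch_size 32 --apply_chat_template --fewshot_as_multiturn --num_fewshot 3"),
    "GPQA": ("accelerate launch -m lm_eval", "leaderboard_gpqa", "--batch_size 4 --apply_chat_template --fewshot_as_multiturn"),
    "MMLU-Pro": ("accelerate launch -m lm_eval", "leaderboard_mmlu_pro", "--batch_size 32 --apply_chat_template --fewshot_as_multiturn --num_fewshot 5"),
    "Math Level-5": ("accelerate launch -m lm_eval", "leaderboard_math_hard", "--batch_size 32 --apply_chat_template --fewshot_as_multiturn --num_fewshot 4"),
    "MuSR": ("lm_eval", "leaderboard_musr", "--device cuda:0 --batch_size 32 --apply_chat_template --fewshot_as_multiturn"),
    "ARC-C": ("lm_eval", "arc_challenge", "--device cuda:0 --batch_size 32 --num_fewshot 0 --apply_chat_template"),
    "TruthfulQA": ("lm_eval", "truthfulqa_mc2", "--device cuda:0 --batch_size 128 --apply_chat_template --fewshot_as_multiturn"),
    "AGIEval": ("lm_eval", "agieval_nous", "--device cuda:0 --batch_size 32 --apply_chat_template"),  # master branch
}


def build_command(benchmark: str, model_path: str) -> list:
    """Return the argv list for one benchmark and one model."""
    launcher, task, extra = BENCHMARKS[benchmark]
    return (
        shlex.split(launcher)
        + ["--model", "hf", "--model_args", MODEL_ARGS.format(model=model_path), "--tasks", task]
        + shlex.split(extra)
    )


for model in ["akjindal53244/Llama-3.1-Storm-8B", "meta-llama/Meta-Llama-3.1-8B-Instruct"]:
    for name in BENCHMARKS:
        print(shlex.join(build_command(name, model)))
```

Passing each list to subprocess.run(cmd, check=True) would execute it; printing keeps the sketch free of side effects.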

BFCL

We used the following prompt for the Llama-3.1-Storm-8B evaluation:

You are a function calling AI model. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into function. The user may use the terms function calling or tool use interchangeably.

Here are the available functions:
<tools>{}</tools>

Follow the below guidelines:
1. If any tool needed to answer the query is not available, you must return an empty list "[]" as response.
2. Else if query does not provide any must-have argument of a required tool, you must return an empty list "[]" as response. 
3. Else, for each function call you must return a json object in response with function name and arguments within <tool_call></tool_call> XML tags in the format:
<tool_call>{"tool_name": <function-name>, "tool_arguments": <args-dict>}</tool_call>

Our experiments with Llama-3.1-Storm-8B revealed an interesting capability: the model can accurately handle cases where required tools or arguments are missing, despite not being specifically trained on such scenarios. By adding two straightforward instructions (guidelines 1 and 2) to the BFCL prompt, we leveraged the model's enhanced instruction-following abilities to address these edge cases. This shows that Llama-3.1-Storm-8B can handle many use cases with prompt engineering alone, thanks to its strong instruction-following capabilities.
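The BFCL prompt contains literal braces in its tool-call format line, so filling the <tools>{}</tools> slot with Python's str.format() would raise a KeyError. A minimal sketch, assuming the tool list is serialized with json.dumps, substitutes only that slot with str.replace() and treats a bare "[]" reply as "no call", per guidelines 1 and 2. The helper names build_bfcl_prompt and declined are hypothetical.

```python
import json

BFCL_PROMPT = """You are a function calling AI model. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into function. The user may use the terms function calling or tool use interchangeably.

Here are the available functions:
<tools>{}</tools>

Follow the below guidelines:
1. If any tool needed to answer the query is not available, you must return an empty list "[]" as response.
2. Else if query does not provide any must-have argument of a required tool, you must return an empty list "[]" as response.
3. Else, for each function call you must return a json object in response with function name and arguments within <tool_call></tool_call> XML tags in the format:
<tool_call>{"tool_name": <function-name>, "tool_arguments": <args-dict>}</tool_call>"""


def build_bfcl_prompt(tools: list) -> str:
    # Replace only the first "{}" (the tools slot); the braces in the format line stay intact
    return BFCL_PROMPT.replace("{}", json.dumps(tools, ensure_ascii=False), 1)


def declined(response: str) -> bool:
    # Guidelines 1 and 2 ask for an empty list when a tool or required argument is missing
    return response.strip() == "[]"


tools = [{"name": "peers", "description": "Retrieves a list of company peers given a stock symbol."}]
print(build_bfcl_prompt(tools))
```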
