Upload RAG_using_Llama3.py.py

RAG_using_Llama3.py.py (ADDED, +153 -0)
# -*- coding: utf-8 -*-
"""RAG_using_Llama3.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1b-ZDo3QQ-axgm804UlHu3ohZwnoXz5L1

# Install dependencies
"""

!pip install -q datasets sentence-transformers faiss-cpu accelerate

from huggingface_hub import notebook_login
notebook_login()

"""# Embed the dataset

Embedding the whole dataset is slow, so consider saving the results; a sketch follows the embedding step below.
"""

from datasets import load_dataset

dataset = load_dataset("KarthikaRajagopal/wikipedia-2")

dataset

from sentence_transformers import SentenceTransformer
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# embed the dataset
def embed(batch):
    # you can also combine multiple columns here, for example the title and the text
    information = batch["text"]
    return {"embeddings": ST.encode(information)}

dataset = dataset.map(embed, batched=True, batch_size=16)
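# Optional: persist the embedded dataset locally so later runs can skip
# re-embedding. A minimal sketch; the "wikipedia-2-embedded" path is an
# arbitrary choice, not part of the original notebook.
dataset.save_to_disk("wikipedia-2-embedded")

# In a later session, reload with:
# from datasets import load_from_disk
# dataset = load_from_disk("wikipedia-2-embedded")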

# Push the embedded dataset to your Hugging Face repository on a dedicated branch
dataset.push_to_hub("KarthikaRajagopal/wikipedia-2", revision="embedded")

# From now on, the precomputed embeddings can be reloaded instead of recomputed
dataset = load_dataset("KarthikaRajagopal/wikipedia-2", revision="embedded")

data = dataset["train"]
data = data.add_faiss_index("embeddings")  # column that holds the embeddings of the dataset

def search(query: str, k: int = 3):
    """Embed a new query and return the k most similar examples."""
    embedded_query = ST.encode(query)  # embed the new query
    scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
        "embeddings", embedded_query,  # compare the query embedding with the dataset embeddings
        k=k  # keep only the top k results
    )
    return scores, retrieved_examples

scores, result = search("anarchy", 4)  # search for "anarchy" and get the 4 best matches

# the lower the score, the closer the match
scores

result["title"]

print(result["text"][0])
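# Quick sanity check on the ranking: pair each distance with its article
# title (a small added example, not in the original notebook).
for score, title in zip(scores, result["title"]):
    print(f"{score:.3f}  {title}")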

"""# Chatbot on top of the retrieved results"""

!pip install -q datasets sentence-transformers faiss-cpu accelerate bitsandbytes

# The retrieval setup is repeated here so this section can run in a fresh session
from sentence_transformers import SentenceTransformer
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

from datasets import load_dataset

dataset = load_dataset("KarthikaRajagopal/wikipedia-2", revision="embedded")

data = dataset["train"]
data = data.add_faiss_index("embeddings")  # column that holds the embeddings of the dataset

def search(query: str, k: int = 3):
    """Embed a new query and return the k most similar examples."""
    embedded_query = ST.encode(query)  # embed the new query
    scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
        "embeddings", embedded_query,  # compare the query embedding with the dataset embeddings
        k=k  # keep only the top k results
    )
    return scores, retrieved_examples

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# load the model with 4-bit NF4 quantization so it fits on a single GPU
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config
)
# Llama 3 marks the end of a turn with <|eot_id|>, so stop on either token
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer."""

def format_prompt(prompt, retrieved_documents, k):
    """Build the user prompt from the question and the k retrieved documents."""
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # truncate the context to avoid GPU OOM
    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
    # apply the Llama 3 chat template and generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]  # keep only the newly generated tokens
    return tokenizer.decode(response, skip_special_tokens=True)

def rag_chatbot(prompt: str, k: int = 2):
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)

rag_chatbot("what's anarchy ?", k=2)
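# A larger k retrieves more context, though the 2000-character truncation in
# generate() may clip it. A hypothetical follow-up query:
rag_chatbot("who are the most influential anarchist thinkers ?", k=3)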