KRISH09bha committed · verified · Commit 42ac404 · Parent: 62b1d20

Update main.py

Files changed (1): main.py (+8 −3)
main.py CHANGED
```diff
@@ -1,4 +1,4 @@
-import os
+
 import tempfile
 import requests
 from fastapi import FastAPI, HTTPException, Header, Request
@@ -14,10 +14,10 @@ from sentence_transformers import SentenceTransformer
 import faiss
 import numpy as np
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+# Load environment variables
 
 import os
 os.environ["HF_HOME"] = "./cache"
-# Load environment variables
 load_dotenv()
 HF_API_TOKEN = os.getenv("HF_API_TOKEN")
 API_KEY = os.getenv("API_KEY")
@@ -37,7 +37,12 @@ EMBED_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
 model_name = "deepseek-ai/deepseek-llm-7b-base"
 hf_token = os.getenv("HF_API_TOKEN")  # Make sure your .env has HF_API_TOKEN
 tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", token=hf_token)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",
+    token=hf_token,
+    offload_folder="./cache/offload"
+)
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
 def query_llm(question: str, context_chunks: list):
```
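The substantive change is the `AutoModelForCausalLM.from_pretrained` call: with `device_map="auto"`, accelerate spreads the 7B model across GPU, CPU, and, when neither fits everything, disk, and disk offload fails unless an `offload_folder` is supplied. Note the file keeps its second `import os` further down (new line 19), so dropping the duplicate import at line 1 leaves the module runnable. Below is a minimal sketch for checking where the layers actually landed, assuming `model` was loaded as in the diff (accelerate populates `model.hf_device_map` whenever a `device_map` is used):

```python
from collections import Counter

# Minimal sketch, assuming `model` was loaded as in the diff above.
# accelerate records its placement decisions in model.hf_device_map;
# any module mapped to "disk" is served from the offload_folder.
placement = Counter(model.hf_device_map.values())
print(placement)  # e.g. Counter({0: 25, "cpu": 5, "disk": 2}) -- counts illustrative

for module, device in model.hf_device_map.items():
    if device == "disk":
        print(f"{module} offloaded to ./cache/offload")
```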