Raj-Maharajwala committed on
Commit 72d3baf · verified · 1 Parent(s): ccbf47f

Update inference_open-insurance-llm-gguf.py

Files changed (1)
  1. inference_open-insurance-llm-gguf.py +101 -228
inference_open-insurance-llm-gguf.py CHANGED
@@ -1,159 +1,47 @@
  import os
  import time
- import logging
- import sys
- import psutil
- import datetime
- import traceback
- import multiprocessing
  from pathlib import Path
  from llama_cpp import Llama
- from typing import Optional, Dict, Any
- from dataclasses import dataclass
  from rich.console import Console
- from rich.logging import RichHandler
- from contextlib import contextmanager
- from rich.traceback import install
- from rich.theme import Theme
  from huggingface_hub import hf_hub_download
- # from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
- # Install rich traceback handler
- install(show_locals=True)

  @dataclass
  class ModelConfig:
      model_name: str = "Raj-Maharajwala/Open-Insurance-LLM-Llama3-8B-GGUF"
      model_file: str = "open-insurance-llm-q4_k_m.gguf"
-     # model_file: str = "open-insurance-llm-q8_0.gguf"
-     # model_file: str = "open-insurance-llm-q5_k_m.gguf"
-     max_tokens: int = 1000
-     top_k: int = 15
-     top_p: float = 0.2
-     repeat_penalty: float = 1.2
-     num_beams: int = 4
-     n_gpu_layers: int = -2  # -1 for complete GPU usage
-     temperature: float = 0.1  # Coherent(0.1) vs Creativity(0.8)
-     n_ctx: int = 2048  # 2048 - 8192 -> As per Llama 3 Full Capacity
-     n_batch: int = 256
-     verbose: bool = False
-     use_mmap: bool = False
-     use_mlock: bool = True
-     offload_kqv: bool = True
-
- class CustomFormatter(logging.Formatter):
-     """Enhanced formatter with detailed context for different log levels"""
-     FORMATS = {
-         logging.DEBUG: "🔍 %(asctime)s - %(name)s - [%(filename)s:%(lineno)d] - %(levelname)s - %(message)s",
-         logging.INFO: "ℹ️ %(asctime)s - %(name)s - [%(funcName)s] - %(levelname)s - %(message)s",
-         logging.WARNING: "⚠️ %(asctime)s - %(name)s - [%(funcName)s] - %(levelname)s - %(message)s\nContext: %(pathname)s",
-         logging.ERROR: "❌ %(asctime)s - %(name)s - [%(funcName)s:%(lineno)d] - %(levelname)s - %(message)s",
-         logging.CRITICAL: """🚨 %(asctime)s - %(name)s - %(levelname)s
-         Location: %(pathname)s:%(lineno)d
-         Function: %(funcName)s
-         Process: %(process)d
-         Thread: %(thread)d
-         Message: %(message)s
-         Memory: %(memory).2fMB
-         """
-     }
-
-     def format(self, record):
-         # Add memory usage information
-         if not hasattr(record, 'memory'):
-             record.memory = psutil.Process().memory_info().rss / (1024 * 1024)
-
-         log_fmt = self.FORMATS.get(record.levelno)
-         formatter = logging.Formatter(log_fmt, datefmt='%Y-%m-%d %H:%M:%S')
-
-         # Add performance metrics if available
-         if hasattr(record, 'duration'):
-             record.message = f"{record.message}\nDuration: {record.duration:.2f}s"
-
-         return formatter.format(record)
-
- def setup_logging(log_dir: str = "logs") -> logging.Logger:
-     """Enhanced logging setup with multiple handlers and log files"""
-     Path(log_dir).mkdir(exist_ok=True)
-     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-     log_path = (Path(log_dir) / f"l_{timestamp}")
-     log_path.mkdir(exist_ok=True)
-
-     # Create logger
-     logger = logging.getLogger("InsuranceLLM")
-     # Clear any existing handlers
-     logger.handlers.clear()
-     logger.setLevel(logging.DEBUG)
-
-     # Create handlers with level-specific files
-     handlers = {
-         'debug': (logging.FileHandler(log_path / f"debug_{timestamp}.log"), logging.DEBUG),
-         'info': (logging.FileHandler(log_path / f"info_{timestamp}.log"), logging.INFO),
-         'error': (logging.FileHandler(log_path / f"error_{timestamp}.log"), logging.ERROR),
-         'critical': (logging.FileHandler(log_path / f"critical_{timestamp}.log"), logging.CRITICAL),
-         'console': (RichHandler(
-             console=Console(theme=custom_theme),
-             show_time=True,
-             show_path=False,
-             enable_link_path=True
-         ), logging.INFO)
-     }
-
-     # Configure handlers
-     formatter = CustomFormatter()
-     for (handler, level) in handlers.values():
-         handler.setLevel(level)
-         handler.setFormatter(formatter)
-         logger.addHandler(handler)
-
-     # Log startup information (will now appear only once)
-     logger.info(f"Starting new session {timestamp}")
-     logger.info(f"Log directory: {log_dir}")
-     return logger
-

- # Custom theme configuration
- custom_theme = Theme({"info": "bold cyan", "warning": "bold yellow", "error": "bold red", "critical": "bold white on red", "success": "bold green", "timestamp": "bold magenta", "metrics": "bold blue", "memory": "bold yellow", "performance": "bold cyan"})

- console = Console(theme=custom_theme)
-
- class PerformanceMetrics:
-     def __init__(self):
-         self.start_time = time.time()
-         self.tokens = 0
-         self.response_times = []
-         self.last_reset = self.start_time
-
-     def reset_timer(self):
-         """Reset the timer for individual response measurements"""
-         self.last_reset = time.time()
-
-     def update(self, tokens: int):
-         self.tokens += tokens
-         response_time = time.time() - self.last_reset
-         self.response_times.append(response_time)
-
-     @property
-     def elapsed_time(self) -> float:
-         return time.time() - self.start_time
-
-     @property
-     def last_response_time(self) -> float:
-         return self.response_times[-1] if self.response_times else 0

  class InsuranceLLM:
      def __init__(self, config: ModelConfig):
          self.config = config
-         self.llm_ctx: Optional[Llama] = None
-         self.metrics = PerformanceMetrics()
-         self.logger = setup_logging()
-
-         nvidia_llama3_chatqa_system = (
              "This is a chat between a user and an artificial intelligence assistant. "
              "The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. "
              "The assistant should also indicate when the answer cannot be found in the context. "
-         )
-         enhanced_system_message = (
-             "You are an expert and experienced from the Insurance domain with extensive insurance knowledge and "
              "professional writer skills, especially about insurance policies. "
              "Your name is OpenInsuranceLLM, and you were developed by Raj Maharajwala. "
              "You are willing to help answer the user's query with a detailed explanation. "
@@ -161,32 +49,22 @@ class InsuranceLLM:
              "complex coverage plans, or other pertinent insurance concepts. Use precise insurance terminology while "
              "still aiming to make the explanation clear and accessible to a general audience."
          )
-         self.full_system_message = nvidia_llama3_chatqa_system + enhanced_system_message
-
-     @contextmanager
-     def timer(self, description: str):
-         start_time = time.time()
-         yield
-         elapsed_time = time.time() - start_time
-         self.logger.info(f"{description}: {elapsed_time:.2f}s")

      def download_model(self) -> str:
          try:
-             with console.status("[bold green]Downloading model..."):
                  model_path = hf_hub_download(
                      self.config.model_name,
                      filename=self.config.model_file,
                      local_dir=os.path.join(os.getcwd(), 'gguf_dir')
                  )
-             self.logger.info(f"Model downloaded successfully to {model_path}")
              return model_path
          except Exception as e:
-             self.logger.error(f"Error downloading model: {str(e)}")
              raise

      def load_model(self) -> None:
          try:
-             # self.check_metal_support()
              quantized_path = os.path.join(os.getcwd(), "gguf_dir")
              directory = Path(quantized_path)

@@ -195,7 +73,7 @@ class InsuranceLLM:
              except IndexError:
                  model_path = self.download_model()

-             with console.status("[bold green]Loading model..."):
                  self.llm_ctx = Llama(
                      model_path=model_path,
                      n_gpu_layers=self.config.n_gpu_layers,
@@ -207,40 +85,37 @@ class InsuranceLLM:
                      use_mmap=self.config.use_mmap,
                      offload_kqv=self.config.offload_kqv
                  )
-             self.logger.info("Model loaded successfully")
-
          except Exception as e:
-             self.logger.error(f"Error loading model: {str(e)}")
              raise

-     def get_prompt(self, question: str, context: str = "") -> str:
          if context:
-             return (
-                 f"System: {self.full_system_message}\n\n"
-                 f"User: Context: {context}\nQuestion: {question}\n\n"
-                 "Assistant:"
-             )
-         return (
-             f"System: {self.full_system_message}\n\n"
-             f"User: Question: {question}\n\n"
-             "Assistant:"
-         )
-
-
-     def generate_response(self, prompt: str) -> Dict[str, Any]:
          if not self.llm_ctx:
              raise RuntimeError("Model not loaded. Call load_model() first.")

          try:
-             response = {"text": "", "tokens": 0}
-
-             # Print the initial prompt
-             # print("Assistant: ", end="", flush=True)
-             console.print("\n[bold cyan]Assistant: [/bold cyan]", end="")
-
-             # Initialize complete response
-             complete_response = ""
-
              for chunk in self.llm_ctx.create_completion(
                  prompt,
                  max_tokens=self.config.max_tokens,
@@ -251,43 +126,42 @@ class InsuranceLLM:
                  stream=True
              ):
                  text_chunk = chunk["choices"][0]["text"]
-                 response["text"] += text_chunk
-                 response["tokens"] += 1
-
-                 # Append to complete response
                  complete_response += text_chunk
-
-                 # Use simple print for streaming output
                  print(text_chunk, end="", flush=True)
-
-             # Print final newline
              print()

-             return response
-
-         except RuntimeError as e:
-             if "llama_decode returned -3" in str(e):
-                 self.logger.error("Memory allocation failed. Try reducing context window or batch size")
-             raise
-
-     def run_inference_loop(self):
          try:
              self.load_model()
-             console.print("\n[bold green]Welcome to Open-Insurance-LLM![/bold green]")
-             console.print("Enter your questions (type '/bye', 'exit', or 'quit' to end the session)\n")
-             console.print("Optional: You can provide context by typing 'context:' followed by your context, then 'question:' followed by your question\n")
-             memory_used = psutil.Process().memory_info().rss / 1024 / 1024
-             console.print(f"[dim]Memory usage: {memory_used:.2f} MB[/dim]")
              while True:
                  try:
-                     user_input = console.input("[bold cyan]User:[/bold cyan] ").strip()

                      if user_input.lower() in ["exit", "/bye", "quit"]:
-                         console.print(f"[dim]Total tokens uptill now: {self.metrics.tokens}[/dim]")
-                         console.print(f"[dim]Total Session Time: {self.metrics.elapsed_time:.2}[/dim]")
-                         console.print("\n[bold green]Thank you for using OpenInsuranceLLM![/bold green]")
                          break

                      context = ""
                      question = user_input
                      if "context:" in user_input.lower() and "question:" in user_input.lower():
@@ -295,51 +169,50 @@ class InsuranceLLM:
                          context = parts[0].replace("context:", "").strip()
                          question = parts[1].strip()

-                     prompt = self.get_prompt(question, context)
-
-                     # Reset timer before generation
-                     self.metrics.reset_timer()
-
-                     # Generate response
-                     response = self.generate_response(prompt)
-
-                     # Update metrics after generation
-                     self.metrics.update(response["tokens"])

-
                      # Print metrics
-                     console.print(f"[dim]Average tokens/sec: {response['tokens']/(self.metrics.last_response_time if self.metrics.last_response_time!=0 else 1):.2f} ||[/dim]",
-                                   f"[dim]Tokens generated: {response['tokens']} ||[/dim]",
-                                   f"[dim]Response time: {self.metrics.last_response_time:.2f}s[/dim]", end="\n\n\n")
-
                  except KeyboardInterrupt:
-                     console.print("\n[yellow]Input interrupted. Type '/bye', 'exit', or 'quit' to quit.[/yellow]")
                      continue
                  except Exception as e:
-                     self.logger.error(f"Error processing input: {str(e)}")
-                     console.print(f"\n[red]Error: {str(e)}[/red]")
                      continue
-
          except Exception as e:
-             self.logger.error(f"Fatal error in inference loop: {str(e)}")
-             console.print(f"\n[red]Fatal error: {str(e)}[/red]")
          finally:
              if self.llm_ctx:
                  del self.llm_ctx

  def main():
-     if hasattr(multiprocessing, "set_start_method"):
-         multiprocessing.set_start_method("spawn", force=True)
      try:
          config = ModelConfig()
          llm = InsuranceLLM(config)
-         llm.run_inference_loop()
      except KeyboardInterrupt:
-         console.print("\n[yellow]Program interrupted by user[/yellow]")
      except Exception as e:
-         error_msg = f"Application error: {str(e)}"
-         logging.error(error_msg)
-         console.print(f"\n[red]{error_msg}[/red]")

  if __name__ == "__main__":
      main()

inference_open-insurance-llm-gguf.py (updated file)
  import os
  import time
  from pathlib import Path
  from llama_cpp import Llama
  from rich.console import Console
  from huggingface_hub import hf_hub_download
+ from dataclasses import dataclass
+ from typing import List, Dict, Any, Tuple

  @dataclass
  class ModelConfig:
+     # Parameters tuned for coherent responses and efficient performance on devices like a MacBook Air M2
      model_name: str = "Raj-Maharajwala/Open-Insurance-LLM-Llama3-8B-GGUF"
      model_file: str = "open-insurance-llm-q4_k_m.gguf"
+     # model_file: str = "open-insurance-llm-q8_0.gguf"  # 8-bit quantization; higher precision and quality, more resource usage
+     # model_file: str = "open-insurance-llm-q5_k_m.gguf"  # 5-bit quantization; balance between quality and resource efficiency
+     max_tokens: int = 1000  # Maximum number of tokens to generate in a single output
+     temperature: float = 0.1  # Controls randomness; lower values scale the distribution toward more coherent responses
+     top_k: int = 15  # After temperature scaling, consider only the 15 most probable tokens during sampling
+     top_p: float = 0.2  # Within that set, nucleus sampling keeps tokens up to a cumulative probability of 20%
+     repeat_penalty: float = 1.2  # Penalize repeated tokens to reduce redundancy
+     num_beams: int = 4  # Number of beams for beam search; higher values improve quality at the cost of speed
+     n_gpu_layers: int = -2  # Layers to offload to GPU; -1 for full GPU utilization, -2 for automatic configuration
+     n_ctx: int = 2048  # Context window size; Llama 3 models support up to 8192 tokens of context
+     n_batch: int = 256  # Tokens processed simultaneously; adjust to available hardware (512 suggested)
+     verbose: bool = False  # Enable verbose logging for debugging
+     use_mmap: bool = False  # Memory-map the model to reduce RAM usage; set True on memory-constrained systems
+     use_mlock: bool = True  # Lock the model in RAM to prevent swapping; helps when RAM is sufficient
+     offload_kqv: bool = True  # Offload key/query/value matrices to the GPU to accelerate inference
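The dataclass defaults above are tuned for laptop-class hardware; a minimal, hypothetical sketch of overriding a few of them at instantiation (field names come from ModelConfig above, the values are only illustrative):

config = ModelConfig(
    n_gpu_layers=-1,  # illustrative: offload all layers on a machine with enough VRAM / unified memory
    n_ctx=4096,       # illustrative: larger context window, still within Llama 3's 8192-token limit
    temperature=0.8,  # illustrative: trade some coherence for more creative answers
)
llm = InsuranceLLM(config)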


  class InsuranceLLM:
      def __init__(self, config: ModelConfig):
          self.config = config
+         self.llm_ctx = None
+         self.console = Console()
+         self.conversation_history: List[Dict[str, str]] = []
+
+         self.system_message = (
              "This is a chat between a user and an artificial intelligence assistant. "
              "The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. "
              "The assistant should also indicate when the answer cannot be found in the context. "
+             "You are an expert from the Insurance domain with extensive insurance knowledge and "
              "professional writer skills, especially about insurance policies. "
              "Your name is OpenInsuranceLLM, and you were developed by Raj Maharajwala. "
              "You are willing to help answer the user's query with a detailed explanation. "

              "complex coverage plans, or other pertinent insurance concepts. Use precise insurance terminology while "
              "still aiming to make the explanation clear and accessible to a general audience."
          )

      def download_model(self) -> str:
          try:
+             with self.console.status("[bold green]Downloading model..."):
                  model_path = hf_hub_download(
                      self.config.model_name,
                      filename=self.config.model_file,
                      local_dir=os.path.join(os.getcwd(), 'gguf_dir')
                  )
              return model_path
          except Exception as e:
+             self.console.print(f"[red]Error downloading model: {str(e)}[/red]")
              raise

      def load_model(self) -> None:
          try:
              quantized_path = os.path.join(os.getcwd(), "gguf_dir")
              directory = Path(quantized_path)

              except IndexError:
                  model_path = self.download_model()

+             with self.console.status("[bold green]Loading model..."):
                  self.llm_ctx = Llama(
                      model_path=model_path,
                      n_gpu_layers=self.config.n_gpu_layers,

                      use_mmap=self.config.use_mmap,
                      offload_kqv=self.config.offload_kqv
                  )
          except Exception as e:
+             self.console.print(f"[red]Error loading model: {str(e)}[/red]")
              raise

+     def build_conversation_prompt(self, new_question: str, context: str = "") -> str:
+         prompt = f"System: {self.system_message}\n\n"
+
+         # Add conversation history
+         for exchange in self.conversation_history:
+             prompt += f"User: {exchange['user']}\n\n"
+             prompt += f"Assistant: {exchange['assistant']}\n\n"
+
+         # Add the new question
          if context:
+             prompt += f"User: Context: {context}\nQuestion: {new_question}\n\n"
+         else:
+             prompt += f"User: {new_question}\n\n"
+
+         prompt += "Assistant:"
+         return prompt

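For illustration only: with one earlier exchange stored in conversation_history, build_conversation_prompt assembles a plain-text prompt of roughly this shape (system message and answer abbreviated; the questions are made up):

System: This is a chat between a user and an artificial intelligence assistant. ...

User: What does a deductible mean?

Assistant: A deductible is the amount you pay out of pocket before coverage applies. ...

User: Context: The policy has a $500 deductible.
Question: Does the deductible apply per claim or per year?

Assistant: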
+     def generate_response(self, prompt: str) -> Tuple[str, int, float]:
          if not self.llm_ctx:
              raise RuntimeError("Model not loaded. Call load_model() first.")
+
+         self.console.print("[bold cyan]Assistant: [/bold cyan]", end="")
+         complete_response = ""
+         token_count = 0
+         start_time = time.time()

          try:
              for chunk in self.llm_ctx.create_completion(
                  prompt,
                  max_tokens=self.config.max_tokens,

                  stream=True
              ):
                  text_chunk = chunk["choices"][0]["text"]
                  complete_response += text_chunk
+                 token_count += 1
                  print(text_chunk, end="", flush=True)
+
+             elapsed_time = time.time() - start_time
              print()
+             return complete_response, token_count, elapsed_time
+         except Exception as e:
+             self.console.print(f"\n[red]Error generating response: {str(e)}[/red]")
+             return f"I encountered an error while generating a response. Please try again or ask a different question.", 0, 0

+     def run_chat(self):
          try:
              self.load_model()
+             self.console.print("\n[bold green]Welcome to Open-Insurance-LLM![/bold green]")
+             self.console.print("Enter your questions (type '/bye', 'exit', or 'quit' to end the session)\n")
+             self.console.print("Optional: You can provide context by typing 'context:' followed by your context, then 'question:' followed by your question\n")
+             self.console.print("Your conversation history will be maintained for context-aware responses.\n")
+
+             total_tokens = 0
+
              while True:
                  try:
+                     user_input = self.console.input("[bold cyan]User:[/bold cyan] ").strip()

                      if user_input.lower() in ["exit", "/bye", "quit"]:
+                         self.console.print(f"\n[dim]Total tokens: {total_tokens}[/dim]")
+                         self.console.print("\n[bold green]Thank you for using Open-Insurance-LLM![/bold green]")
                          break

+                     # Reset conversation with command
+                     if user_input.lower() == "/reset":
+                         self.conversation_history = []
+                         self.console.print("[yellow]Conversation history has been reset.[/yellow]")
+                         continue
+
                      context = ""
                      question = user_input
                      if "context:" in user_input.lower() and "question:" in user_input.lower():

                          context = parts[0].replace("context:", "").strip()
                          question = parts[1].strip()

+                     prompt = self.build_conversation_prompt(question, context)
+                     response, tokens, elapsed_time = self.generate_response(prompt)
+
+                     # Add to conversation history
+                     self.conversation_history.append({
+                         "user": question,
+                         "assistant": response
+                     })
+
+                     # Update total tokens
+                     total_tokens += tokens

                      # Print metrics
+                     tokens_per_sec = tokens / elapsed_time if elapsed_time > 0 else 0
+                     self.console.print(
+                         f"[dim]Tokens: {tokens} || " +
+                         f"Time: {elapsed_time:.2f}s || " +
+                         f"Speed: {tokens_per_sec:.2f} tokens/sec[/dim]"
+                     )
+                     print()  # Add a blank line after each response
+
                  except KeyboardInterrupt:
+                     self.console.print("\n[yellow]Input interrupted. Type '/bye', 'exit', or 'quit' to quit.[/yellow]")
                      continue
                  except Exception as e:
+                     self.console.print(f"\n[red]Error processing input: {str(e)}[/red]")
                      continue

          except Exception as e:
+             self.console.print(f"\n[red]Fatal error: {str(e)}[/red]")
          finally:
              if self.llm_ctx:
                  del self.llm_ctx

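A few example inputs for the interactive loop above; the '/reset', '/bye', and 'context: ... question: ...' forms come straight from run_chat, while the insurance questions themselves are only illustrative:

User: What is the difference between a copay and coinsurance?
User: context: My policy has a $500 deductible and 20% coinsurance. question: How much would I owe on a $2,000 claim?
User: /reset
User: /bye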
  def main():
      try:
          config = ModelConfig()
          llm = InsuranceLLM(config)
+         llm.run_chat()
      except KeyboardInterrupt:
+         print("\nProgram interrupted by user")
      except Exception as e:
+         print(f"\nApplication error: {str(e)}")
+

  if __name__ == "__main__":
      main()
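For scripted, non-interactive use, a minimal sketch that drives the same class directly; the question text is illustrative and error handling is omitted:

config = ModelConfig()
llm = InsuranceLLM(config)
llm.load_model()  # downloads the GGUF into ./gguf_dir on first run, then loads it
question = "What does 'premium' mean in an insurance policy?"  # illustrative question
prompt = llm.build_conversation_prompt(question)
response, tokens, seconds = llm.generate_response(prompt)
llm.conversation_history.append({"user": question, "assistant": response})
print(f"\n{tokens} tokens in {seconds:.2f}s")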