Commit 0db5f78 (verified) · committed by Raj-Maharajwala · 1 Parent(s): 1b77eb0

Updated Inference with Q4_K_M in Readme.md

Files changed (1): README.md (+97 -88)
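The commit switches the default quantization from Q8_0 to Q4_K_M. For readers who want to fetch that file ahead of time instead of relying on the script's own download_model() helper, a minimal sketch using huggingface_hub (the local_dir value simply mirrors the gguf_dir folder the script globs; adjust as needed):

# Sketch: pre-download the Q4_K_M quant referenced in this commit.
# Assumes `pip install huggingface_hub`; repo id and filename are taken from the diff below.
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="Raj-Maharajwala/Open-Insurance-LLM-Llama3-8B-GGUF",
    filename="open-insurance-llm-q4_k_m.gguf",
    local_dir="gguf_dir",  # the directory the README script searches for the model file
)
print(f"Model downloaded to: {model_path}")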
README.md CHANGED
@@ -148,15 +148,16 @@ install(show_locals=True)
 @dataclass
 class ModelConfig:
     model_name: str = "Raj-Maharajwala/Open-Insurance-LLM-Llama3-8B-GGUF"
-    model_file: str = "open-insurance-llm-q8_0.gguf"
-    # model_file: str = "open-insurance-llm-q4_k_m.gguf"
+    model_file: str = "open-insurance-llm-q4_k_m.gguf"
+    # model_file: str = "open-insurance-llm-q8_0.gguf"
+    # model_file: str = "open-insurance-llm-q5_k_m.gguf"
     max_tokens: int = 1000
     top_k: int = 15
-    top_p: float = 0.2
+    top_p: float = 0.2
     repeat_penalty: float = 1.2
     num_beams: int = 4
-    n_gpu_layers: int = -2  # -1 for complete GPU usage
-    temperature: float = 0.1
+    n_gpu_layers: int = -2  #-2 # -1 for complete GPU usage
+    temperature: float = 0.1  # Coherent(0.1) vs Creativity(0.8)
     n_ctx: int = 2048  # 2048 - 8192 -> As per Llama 3 Full Capacity
     n_batch: int = 256
     verbose: bool = False
@@ -165,12 +166,12 @@ class ModelConfig:
     offload_kqv: bool =True

 class CustomFormatter(logging.Formatter):
-    """Enhanced formatter with detailed context for different log levels"""
+    """Enhanced formatter with detailed context for different log levels"""
     FORMATS = {
         logging.DEBUG: "🔍 %(asctime)s - %(name)s - [%(filename)s:%(lineno)d] - %(levelname)s - %(message)s",
         logging.INFO: "ℹ️ %(asctime)s - %(name)s - [%(funcName)s] - %(levelname)s - %(message)s",
         logging.WARNING: "⚠️ %(asctime)s - %(name)s - [%(funcName)s] - %(levelname)s - %(message)s\nContext: %(pathname)s",
-        logging.ERROR: "❌ %(asctime)s - %(name)s - [%(funcName)s:%(lineno)d] - %(levelname)s - %(message)s\nTraceback: %(exc_info)s",
+        logging.ERROR: "❌ %(asctime)s - %(name)s - [%(funcName)s:%(lineno)d] - %(levelname)s - %(message)s",
         logging.CRITICAL: """🚨 %(asctime)s - %(name)s - %(levelname)s
 Location: %(pathname)s:%(lineno)d
 Function: %(funcName)s
@@ -178,7 +179,6 @@ Process: %(process)d
 Thread: %(thread)d
 Message: %(message)s
 Memory: %(memory).2fMB
-%(exc_info)s
 """
     }

@@ -186,18 +186,14 @@ Memory: %(memory).2fMB
         # Add memory usage information
         if not hasattr(record, 'memory'):
             record.memory = psutil.Process().memory_info().rss / (1024 * 1024)
-
-        # Add exception info for ERROR and CRITICAL
-        if record.levelno >= logging.ERROR and not record.exc_info:
-            record.exc_info = traceback.format_exc()
-
+
         log_fmt = self.FORMATS.get(record.levelno)
         formatter = logging.Formatter(log_fmt, datefmt='%Y-%m-%d %H:%M:%S')
-
-        # Add performance metrics
+
+        # Add performance metrics if available
         if hasattr(record, 'duration'):
             record.message = f"{record.message}\nDuration: {record.duration:.2f}s"
-
+
         return formatter.format(record)

 def setup_logging(log_dir: str = "logs") -> logging.Logger:
@@ -206,36 +202,35 @@ def setup_logging(log_dir: str = "logs") -> logging.Logger:
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
     log_path = (Path(log_dir) / f"l_{timestamp}")
     log_path.mkdir(exist_ok=True)
-    # Create separate log files for different levels
+
+    # Create logger
+    logger = logging.getLogger("InsuranceLLM")
+    # Clear any existing handlers
+    logger.handlers.clear()
+    logger.setLevel(logging.DEBUG)
+
+    # Create handlers with level-specific files
     handlers = {
-        'debug': logging.FileHandler(log_path / f"debug_{timestamp}.log"),
-        'info': logging.FileHandler(log_path / f"info_{timestamp}.log"),
-        'error': logging.FileHandler(log_path / f"error_{timestamp}.log"),
-        'critical': logging.FileHandler(log_path / f"critical_{timestamp}.log"),
-        'console': RichHandler(
+        'debug': (logging.FileHandler(log_path / f"debug_{timestamp}.log"), logging.DEBUG),
+        'info': (logging.FileHandler(log_path / f"info_{timestamp}.log"), logging.INFO),
+        'error': (logging.FileHandler(log_path / f"error_{timestamp}.log"), logging.ERROR),
+        'critical': (logging.FileHandler(log_path / f"critical_{timestamp}.log"), logging.CRITICAL),
+        'console': (RichHandler(
             console=Console(theme=custom_theme),
             show_time=True,
             show_path=False,
             enable_link_path=True
-        )
-    }
-    # Set levels for handlers
-    handlers['debug'].setLevel(logging.DEBUG)
-    handlers['info'].setLevel(logging.INFO)
-    handlers['error'].setLevel(logging.ERROR)
-    handlers['critical'].setLevel(logging.CRITICAL)
-    handlers['console'].setLevel(logging.INFO)
-    # Apply formatter to all handlers
+        ), logging.INFO)
+    }
+
+    # Configure handlers
     formatter = CustomFormatter()
-    for handler in handlers.values():
+    for (handler, level) in handlers.values():
+        handler.setLevel(level)
         handler.setFormatter(formatter)
-    # Configure root logger
-    logger = logging.getLogger("InsuranceLLM")
-    logger.setLevel(logging.DEBUG)
-    # Add all handlers
-    for handler in handlers.values():
         logger.addHandler(handler)
-    # Log startup information
+
+    # Log startup information (will now appear only once)
     logger.info(f"Starting new session {timestamp}")
     logger.info(f"Log directory: {log_dir}")
     return logger
@@ -250,23 +245,24 @@ class PerformanceMetrics:
         self.start_time = time.time()
         self.tokens = 0
         self.response_times = []
-
-    def update(self, tokens: int, response_time: float = None):
+        self.last_reset = self.start_time
+
+    def reset_timer(self):
+        """Reset the timer for individual response measurements"""
+        self.last_reset = time.time()
+
+    def update(self, tokens: int):
         self.tokens += tokens
-        if response_time:
-            self.response_times.append(response_time)
-
+        response_time = time.time() - self.last_reset
+        self.response_times.append(response_time)
+
     @property
     def elapsed_time(self) -> float:
         return time.time() - self.start_time
-
-    @property
-    def tokens_per_second(self) -> float:
-        return self.tokens / self.elapsed_time if self.elapsed_time > 0 else 0
-
+
     @property
-    def average_response_time(self) -> float:
-        return sum(self.response_times) / len(self.response_times) if self.response_times else 0
+    def last_response_time(self) -> float:
+        return self.response_times[-1] if self.response_times else 0

 class InsuranceLLM:
     def __init__(self, config: ModelConfig):
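The PerformanceMetrics rework above drops the old tokens_per_second and average_response_time properties in favour of an explicit reset_timer()/update() handshake around each generation. A self-contained sketch of that call order (the class body is copied from the hunk above; the sleep is only a stand-in for an actual generation call):

import time

class PerformanceMetrics:
    # Mirrors the class as changed in this commit.
    def __init__(self):
        self.start_time = time.time()
        self.tokens = 0
        self.response_times = []
        self.last_reset = self.start_time

    def reset_timer(self):
        """Reset the timer for individual response measurements"""
        self.last_reset = time.time()

    def update(self, tokens: int):
        self.tokens += tokens
        response_time = time.time() - self.last_reset
        self.response_times.append(response_time)

    @property
    def elapsed_time(self) -> float:
        return time.time() - self.start_time

    @property
    def last_response_time(self) -> float:
        return self.response_times[-1] if self.response_times else 0

metrics = PerformanceMetrics()
metrics.reset_timer()        # call immediately before generation
time.sleep(0.5)              # stand-in for generate_response(prompt)
metrics.update(tokens=42)    # call once generation has finished
print(f"last response: {metrics.last_response_time:.2f}s, "
      f"{metrics.tokens} tokens in {metrics.elapsed_time:.2f}s total")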
@@ -274,7 +270,7 @@ class InsuranceLLM:
         self.llm_ctx: Optional[Llama] = None
         self.metrics = PerformanceMetrics()
         self.logger = setup_logging()
-
+
         nvidia_llama3_chatqa_system = (
             "This is a chat between a user and an artificial intelligence assistant. "
             "The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. "
@@ -317,12 +313,12 @@ class InsuranceLLM:
             # self.check_metal_support()
             quantized_path = os.path.join(os.getcwd(), "gguf_dir")
             directory = Path(quantized_path)
-
+
             try:
                 model_path = str(list(directory.glob(self.config.model_file))[0])
             except IndexError:
                 model_path = self.download_model()
-
+
             with console.status("[bold green]Loading model..."):
                 self.llm_ctx = Llama(
                     model_path=model_path,
@@ -336,7 +332,7 @@ class InsuranceLLM:
                     offload_kqv=self.config.offload_kqv
                 )
                 self.logger.info("Model loaded successfully")
-
+
         except Exception as e:
             self.logger.error(f"Error loading model: {str(e)}")
             raise
@@ -354,14 +350,20 @@ class InsuranceLLM:
             "Assistant:"
         )

+
     def generate_response(self, prompt: str) -> Dict[str, Any]:
         if not self.llm_ctx:
             raise RuntimeError("Model not loaded. Call load_model() first.")
-
         try:
             response = {"text": "", "tokens": 0}
-            console.print("[bold cyan]Assistant:[/bold cyan]", end=" ")
-
+
+            # Print the initial prompt
+            # print("Assistant: ", end="", flush=True)
+            console.print("\n[bold cyan]Assistant: [/bold cyan]", end="")
+
+            # Initialize complete response
+            complete_response = ""
+
             for chunk in self.llm_ctx.create_completion(
                 prompt,
                 max_tokens=self.config.max_tokens,
@@ -374,67 +376,71 @@ class InsuranceLLM:
                 text_chunk = chunk["choices"][0]["text"]
                 response["text"] += text_chunk
                 response["tokens"] += 1
-                console.print(text_chunk, end="", markup=False)
-
-            console.print()
+
+                # Append to complete response
+                complete_response += text_chunk
+
+                # Use simple print for streaming output
+                print(text_chunk, end="", flush=True)
+
+            # Print final newline
+            print()
+
             return response
-
+
         except RuntimeError as e:
             if "llama_decode returned -3" in str(e):
                 self.logger.error("Memory allocation failed. Try reducing context window or batch size")
             raise

-    def print_metrics(self, response_tokens: int, memory_used: float):
-        try:
-            console.print("\n[dim]Performance Metrics:[/dim]")
-            console.print(f"[dim]Memory usage: {memory_used:.2f} MB[/dim]")
-            console.print(f"[dim]Tokens generated: {response_tokens}[/dim]")
-            console.print(f"[dim]Average tokens/sec: {self.metrics.tokens_per_second:.2f}[/dim]")
-            console.print(f"[dim]Total tokens: {self.metrics.tokens}[/dim]")
-            console.print(f"[dim]Total time: {self.metrics.elapsed_time:.2f}s[/dim]\n")
-        except Exception as e:
-            self.logger.error(f"Error printing metrics: {str(e)}")
-
     def run_inference_loop(self):
         try:
             self.load_model()
             console.print("\n[bold green]Welcome to Open-Insurance-LLM![/bold green]")
             console.print("Enter your questions (type '/bye', 'exit', or 'quit' to end the session)\n")
             console.print("Optional: You can provide context by typing 'context:' followed by your context, then 'question:' followed by your question\n")
-
+            memory_used = psutil.Process().memory_info().rss / 1024 / 1024
+            console.print(f"[dim]Memory usage: {memory_used:.2f} MB[/dim]")
             while True:
                 try:
                     user_input = console.input("[bold cyan]User:[/bold cyan] ").strip()
-
-                    exit_commands = ["exit", "/bye", "quit"]
-                    if user_input.lower() in exit_commands:
+
+                    if user_input.lower() in ["exit", "/bye", "quit"]:
+                        console.print(f"[dim]Total tokens uptill now: {self.metrics.tokens}[/dim]")
+                        console.print(f"[dim]Total Session Time: {self.metrics.elapsed_time:.2}[/dim]")
                         console.print("\n[bold green]Thank you for using OpenInsuranceLLM![/bold green]")
                         break
-
                     context = ""
                     question = user_input
                     if "context:" in user_input.lower() and "question:" in user_input.lower():
                         parts = user_input.split("question:", 1)
                         context = parts[0].replace("context:", "").strip()
                         question = parts[1].strip()
-
+
                     prompt = self.get_prompt(question, context)
-
-                    with self.timer("Response generation"):
-                        response = self.generate_response(prompt)
-
+
+                    # Reset timer before generation
+                    self.metrics.reset_timer()
+
+                    # Generate response
+                    response = self.generate_response(prompt)
+
+                    # Update metrics after generation
                     self.metrics.update(response["tokens"])
-                    memory_used = psutil.Process().memory_info().rss / 1024 / 1024
-                    self.print_metrics(response["tokens"], memory_used)

+                    # Print metrics
+                    console.print(f"[dim]Average tokens/sec: {response['tokens']/(self.metrics.last_response_time if self.metrics.last_response_time!=0 else 1):.2f} ||[/dim]",
+                                  f"[dim]Tokens generated: {response['tokens']} ||[/dim]",
+                                  f"[dim]Response time: {self.metrics.last_response_time:.2f}s[/dim]", end="\n\n\n")
+
                 except KeyboardInterrupt:
                     console.print("\n[yellow]Input interrupted. Type '/bye', 'exit', or 'quit' to quit.[/yellow]")
                     continue
                 except Exception as e:
                     self.logger.error(f"Error processing input: {str(e)}")
-                    console.print(f"\n[red]Error processing input: {str(e)}[/red]")
+                    console.print(f"\n[red]Error: {str(e)}[/red]")
                     continue
-
+
         except Exception as e:
             self.logger.error(f"Fatal error in inference loop: {str(e)}")
             console.print(f"\n[red]Fatal error: {str(e)}[/red]")
@@ -443,6 +449,8 @@ class InsuranceLLM:
             del self.llm_ctx

 def main():
+    if hasattr(multiprocessing, "set_start_method"):
+        multiprocessing.set_start_method("spawn", force=True)
     try:
         config = ModelConfig()
         llm = InsuranceLLM(config)
@@ -450,8 +458,9 @@ def main():
     except KeyboardInterrupt:
         console.print("\n[yellow]Program interrupted by user[/yellow]")
     except Exception as e:
-        logging.error(f"Application error: {str(e)}")
-        console.print(f"\n[red]Fatal error: {str(e)}[/red]")
+        error_msg = f"Application error: {str(e)}"
+        logging.error(error_msg)
+        console.print(f"\n[red]{error_msg}[/red]")

 if __name__ == "__main__":
     main()
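For a quick smoke test of the updated Q4_K_M default outside the full script, a minimal llama-cpp-python sketch (assumes llama-cpp-python is installed, the model file already sits in ./gguf_dir, and the prompt is only illustrative; sampling values are the ones set in ModelConfig above):

# Sketch only: condensed check of the new default quant, not the README's full script.
from llama_cpp import Llama

llm = Llama(
    model_path="gguf_dir/open-insurance-llm-q4_k_m.gguf",
    n_ctx=2048,
    n_batch=256,
    n_gpu_layers=-1,  # -1 offloads all layers, per the comment in ModelConfig
    verbose=False,
)

prompt = (
    "System: This is a chat between a user and an artificial intelligence assistant.\n\n"
    "User: What is a deductible in an insurance policy?\n\n"
    "Assistant:"
)
out = llm.create_completion(
    prompt,
    max_tokens=1000,
    temperature=0.1,
    top_k=15,
    top_p=0.2,
    repeat_penalty=1.2,
)
print(out["choices"][0]["text"].strip())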
 