#!/usr/bin/env python3 # prefill.py # Copyright (c) 2025 Anemll # Licensed under the MIT License import argparse import os import re import glob from pathlib import Path import coremltools as ct from transformers import AutoTokenizer import torch import torch.nn.functional as F import numpy as np import time import yaml import sys # ANSI color codes LIGHT_BLUE = "\033[94m" DARK_BLUE = "\033[34m" LIGHT_GREEN = "\033[92m" RESET_COLOR = "\033[0m" def parse_model_path(path): """Parse model path and return full path with .mlmodelc or .mlpackage extension.""" path = Path(path) # If path exists exactly as specified, return it if path.exists(): return str(path) # Try with both extensions candidates = [ path, # Original path path.with_suffix('.mlmodelc'), # With .mlmodelc path.with_suffix('.mlpackage'), # With .mlpackage Path(str(path) + '.mlmodelc'), # Handle case where extension is included Path(str(path) + '.mlpackage') ] # Try all possible paths for candidate in candidates: if candidate.exists(): print(f"Found model at: {candidate}") return str(candidate) # If we get here, no valid path was found print("\nError: Model not found. Tried following paths:") for candidate in candidates: print(f" {candidate}") raise FileNotFoundError(f"Model not found: {path}") def parse_ffn_filename(path): """Parse FFN model filename to extract chunk information.""" path = Path(path) pattern = r'FFN_PF.*_chunk_(\d+)of(\d+)' match = re.search(pattern, path.name) if match: current_chunk = int(match.group(1)) total_chunks = int(match.group(2)) return current_chunk, total_chunks return None, None def find_all_chunks(base_path): """Find all chunk files matching the base FFN path pattern.""" path = Path(base_path) pattern = re.sub(r'_chunk_\d+of\d+', '_chunk_*', str(path)) return sorted(glob.glob(pattern)) def load_model(path, function_name=None): """Load a CoreML model, handling both .mlmodelc and .mlpackage formats.""" path = Path(path) compute_unit = ct.ComputeUnit.CPU_AND_NE try: if path.suffix == '.mlmodelc': # For compiled models (.mlmodelc), use CompiledMLModel if function_name: return ct.models.CompiledMLModel(str(path), compute_unit, function_name=function_name) else: return ct.models.CompiledMLModel(str(path), compute_unit) else: # For packages (.mlpackage) if function_name: return ct.models.MLModel(str(path), function_name=function_name) else: return ct.models.MLModel(str(path)) except RuntimeError as e: if "valid manifest does not exist" in str(e): print(f"\nError: Could not load compiled model at {path}") print("This might be because:") print("1. The model is not properly compiled") print("2. The model was compiled for a different OS version") print("3. The model needs to be recompiled") print("\nTry using the .mlpackage version instead, or recompile the model.") raise def load_metadata(model, args): # Extract metadata and config parameters metadata = {} if hasattr(model, 'user_defined_metadata'): meta = model.user_defined_metadata # Extract key parameters with defaults metadata['context_length'] = int(meta.get('com.anemll.context_length', 512)) metadata['state_length'] = int(meta.get('com.anemll.state_length', metadata['context_length'])) metadata['batch_size'] = int(meta.get('com.anemll.batch_size', 64)) metadata['lut_bits'] = int(meta.get('com.anemll.lut_bits', 0)) metadata['num_chunks'] = int(meta.get('com.anemll.num_chunks', 1)) print("\nExtracted Parameters:") print(f" Context Length: {metadata['context_length']}") print(f" State Length: {metadata['state_length']}") print(f" Prefill Batch Size: {metadata['batch_size']}") print(f" LUT Bits: {metadata['lut_bits']}") print(f" Number of Chunks: {metadata['num_chunks']}") else: print("\nWarning: No metadata found in model") # Check if model directory name contains context length pattern (ctxXXX) ctx_len = 512 if args.context_length is None: import re ctx_match = re.search(r'ctx(\d+)', str(args.d)) if ctx_match: ctx_len0 = int(ctx_match.group(1)) if 512 <= ctx_len0 <= 8096: ctx_len = ctx_len0 print(f"\nDetected context length {ctx_len} from directory name") else: print(f"\nWarning: No context length found in directory, using default {ctx_len}") else: ctx_len = args.context_length # Use defaults or values from args metadata['context_length'] = ctx_len metadata['state_length'] = ctx_len # Get batch size from args or use default metadata['batch_size'] = getattr(args, 'batch_size', 64) metadata['lut_bits'] = 4 metadata['num_chunks'] = getattr(args, 'num_chunks', 4) print("\nUsing parameters:") print(f" Context Length: {metadata['context_length']}") print(f" State Length: {metadata['state_length']}") print(f" Prefill Batch Size: {metadata['batch_size']}") print(f" LUT Bits: {metadata['lut_bits']}") print(f" Number of Chunks: {metadata['num_chunks']}") # Override with values from args if they exist if hasattr(args, 'batch_size') and args.batch_size is not None: metadata['batch_size'] = args.batch_size print(f"\nOverriding batch size from args: {args.batch_size}") if hasattr(args, 'num_chunks') and args.num_chunks is not None: metadata['num_chunks'] = args.num_chunks print(f"\nOverriding num chunks from args: {args.num_chunks}") return metadata def load_models(args, metadata): """Load all required models and extract metadata.""" print("\nLoading models...") try: # Load embeddings model print("\nLoading embeddings model...") embed_path = parse_model_path(args.embed) print(f"Loading from: {embed_path}") embed_model = load_model(embed_path) print("Embeddings model loaded successfully") metadata = load_metadata(embed_model, args) # Load FFN model(s) print("\nLoading PREFILL functionality only...") ffn_path = parse_model_path(args.ffn) chunk_no, total_chunks = parse_ffn_filename(ffn_path) ffn_models = [] if chunk_no and total_chunks: print(f"\nDetected chunked model with {total_chunks} chunks") # Find and load all chunks chunk_paths = find_all_chunks(ffn_path) if len(chunk_paths) != total_chunks: raise ValueError(f"Found {len(chunk_paths)} chunks but filename indicates {total_chunks} chunks") for chunk_path in chunk_paths: print(f"\nLoading PREFILL function from chunk: {Path(chunk_path).name}") try: # For prefill testing, we only need the prefill function prefill_model = load_model(chunk_path, function_name='prefill') ffn_models.append(prefill_model) print("Chunk loaded successfully (prefill only)") except Exception as e: print(f"Error loading chunk {chunk_path}: {str(e)}") raise metadata = load_metadata(ffn_models[0], args) else: print("\nLoading single model (prefill functionality only)...") ffn_models.append(load_model(ffn_path)) print("Model loaded successfully") return embed_model, ffn_models, metadata except Exception as e: print(f"\nError loading models: {str(e)}") print("\nPlease ensure all model files exist and are accessible.") print("Expected files:") print(f" Embeddings: {args.embed}") print(f" FFN: {args.ffn}") raise def initialize_tokenizer(model_path=None): """Initialize and configure the tokenizer.""" try: tokenizer = AutoTokenizer.from_pretrained( str(model_path), use_fast=False, trust_remote_code=True ) print("\nTokenizer Configuration:") print(f"Tokenizer type: {type(tokenizer)}") print(f"Tokenizer name: {tokenizer.__class__.__name__}") print(f"Vocabulary size: {len(tokenizer)}") if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id print("Set PAD token to EOS token") tokenizer.padding_side = "left" return tokenizer except Exception as e: print(f"\nError: Failed to load tokenizer from {model_path}") print(f"Error details: {str(e)}") raise def make_causal_mask(length, start): """Create causal attention mask.""" mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16) row_indices = np.arange(length).reshape(length, 1) col_indices = np.arange(length).reshape(1, length) mask[:, :, col_indices <= (row_indices + start)] = 0 return mask def initialize_causal_mask(context_length): """Initialize causal mask for transformer attention.""" causal_mask = make_causal_mask(context_length, 0) causal_mask = torch.tensor(causal_mask, dtype=torch.float16) print(f"\nInitialized causal mask for context length {context_length}") return causal_mask def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None, causal_mask=None): """Run prefill on the input sequence.""" # Use provided causal mask or create one if not provided if causal_mask is None: causal_mask = make_causal_mask(context_length, 0) causal_mask = torch.tensor(causal_mask, dtype=torch.float16) # Process in batches batch_pos = 0 while batch_pos < context_pos: batch_end = min(batch_pos + batch_size, context_pos) current_batch_size = batch_end - batch_pos # Get current batch batch_input = input_ids[:, batch_pos:batch_end] # Always pad to full batch size for prefill batch_input = F.pad( batch_input, (0, batch_size - current_batch_size), value=0 ) # Generate position IDs for full batch size position_ids = torch.arange(batch_size, dtype=torch.int32) batch_causal_mask = causal_mask[:, :, :batch_size, :] # Run embeddings with proper batch size hidden_states = torch.from_numpy( embed_model.predict({ 'input_ids': batch_input.numpy(), 'batch_size': np.array([batch_size], dtype=np.int32) })['hidden_states'] ) # Run through FFN chunks with state for ffn_model in ffn_models: # Handle both direct model and dictionary formats if isinstance(ffn_model, dict) and 'prefill' in ffn_model: # For backward compatibility with dictionary format prefill_model = ffn_model['prefill'] else: # Direct access for models loaded with function_name='prefill' prefill_model = ffn_model inputs = { 'hidden_states': hidden_states.numpy(), 'position_ids': position_ids.numpy(), 'causal_mask': batch_causal_mask.numpy(), 'current_pos': np.array([batch_pos], dtype=np.int32) } output = prefill_model.predict(inputs, state) hidden_states = torch.from_numpy(output['output_hidden_states']) batch_pos = batch_end return torch.tensor([context_pos], dtype=torch.int32) def create_unified_state(ffn_models, context_length): """Create unified KV cache state for transformer.""" if hasattr(ffn_models[0], 'make_state'): # Direct access for models loaded with 'prefill' function_name state = ffn_models[0].make_state() print(f"\nCreated unified transformer state for {len(ffn_models)} chunks") return state else: # Fallback for dictionary-based models (for backward compatibility) if isinstance(ffn_models[0], dict) and 'prefill' in ffn_models[0]: state = ffn_models[0]['prefill'].make_state() print(f"\nCreated unified transformer state for {len(ffn_models)} chunks") return state else: state = ffn_models[0].make_state() print("\nCreated unified transformer state") return state def test_prefill_speed(embed_model, ffn_models, tokenizer, batch_size, context_length, num_test_tokens, num_runs=20, test_single_chunk=True): """Test prefill speed with sample token sequences.""" print(f"\n{LIGHT_GREEN}Testing prefill speed for {num_test_tokens} tokens (using internal batch size {batch_size}){RESET_COLOR}") print(f"Running {num_runs} iterations for warmup and measurement") # Create sample input sequence of exactly num_test_tokens tokens sample_text = "This is a test sequence. " * ((num_test_tokens + 4) // 5) # Ensure enough text input_ids = tokenizer(sample_text, return_tensors="pt").input_ids.to(torch.int32) # Trim or pad to exactly num_test_tokens tokens if input_ids.size(1) > num_test_tokens: input_ids = input_ids[:, :num_test_tokens] elif input_ids.size(1) < num_test_tokens: pad_length = num_test_tokens - input_ids.size(1) input_ids = F.pad(input_ids, (0, pad_length), value=tokenizer.pad_token_id) print(f"Sample input sequence length: {input_ids.size(1)} tokens") # Test with all chunks first print(f"\n{LIGHT_BLUE}Testing with all chunks ({len(ffn_models)} chunks){RESET_COLOR}") # Create unified state state_all_chunks = create_unified_state(ffn_models, context_length) # Initialize causal mask causal_mask = initialize_causal_mask(context_length) # Run prefill multiple times for warmup and testing all_chunks_times = [] for i in range(num_runs): # Reset state for each run if i == 0: print("\nStarting warmup runs...") elif i == num_runs // 2: print("\nWarmup complete, starting measurement runs...") start_time = time.time() # Run prefill run_prefill( embed_model, ffn_models, input_ids, input_ids.size(1), # context_pos context_length, batch_size, # Internal batching within run_prefill state_all_chunks, causal_mask ) elapsed = time.time() - start_time all_chunks_times.append(elapsed) # Print progress if i < num_runs // 2: # Warmup phase print(f"Warmup run {i+1}/{num_runs//2}: {elapsed:.4f}s ({batch_size/elapsed:.1f} tokens/s)") else: # Measurement phase print(f"Run {i+1-num_runs//2}/{num_runs//2}: {elapsed:.4f}s ({batch_size/elapsed:.1f} tokens/s)") # Calculate and report statistics for all chunks (excluding warmup runs) all_chunks_measurement_times = all_chunks_times[num_runs // 2:] all_chunks_avg_time = sum(all_chunks_measurement_times) / len(all_chunks_measurement_times) all_chunks_min_time = min(all_chunks_measurement_times) all_chunks_max_time = max(all_chunks_measurement_times) all_chunks_tokens_per_sec = num_test_tokens / all_chunks_avg_time # Use num_test_tokens for speed calculation print(f"\n{LIGHT_BLUE}All Chunks Prefill Speed Results:{RESET_COLOR}") print(f"Number of Chunks: {len(ffn_models)}") print(f"Test Tokens: {num_test_tokens} tokens") print(f"Internal Batch Size: {batch_size} tokens") print(f"Context Size: {context_length} tokens") print(f"Average Time: {all_chunks_avg_time:.4f}s") print(f"Min Time: {all_chunks_min_time:.4f}s") print(f"Max Time: {all_chunks_max_time:.4f}s") print(f"Average Speed: {all_chunks_tokens_per_sec:.1f} tokens/second") print(f"Best Speed: {num_test_tokens / all_chunks_min_time:.1f} tokens/second") # Use num_test_tokens # Test with single chunk if requested and if multiple chunks exist single_chunk_tokens_per_sec = 0 if test_single_chunk and len(ffn_models) > 1: print(f"\n{LIGHT_BLUE}Testing with single chunk (first chunk only){RESET_COLOR}") # Create a list with only the first chunk single_chunk_model = [ffn_models[0]] # Create unified state for single chunk state_single_chunk = create_unified_state(single_chunk_model, context_length) # Run prefill multiple times for single chunk single_chunk_times = [] for i in range(num_runs): if i == 0: print("\nStarting single chunk warmup runs...") elif i == num_runs // 2: print("\nSingle chunk warmup complete, starting measurement runs...") start_time = time.time() # Run prefill with single chunk run_prefill( embed_model, single_chunk_model, input_ids, input_ids.size(1), # context_pos context_length, batch_size, # Internal batching within run_prefill state_single_chunk, causal_mask ) elapsed = time.time() - start_time single_chunk_times.append(elapsed) # Print progress if i < num_runs // 2: # Warmup phase print(f"Single chunk warmup run {i+1}/{num_runs//2}: {elapsed:.4f}s ({batch_size/elapsed:.1f} tokens/s)") else: # Measurement phase print(f"Single chunk run {i+1-num_runs//2}/{num_runs//2}: {elapsed:.4f}s ({batch_size/elapsed:.1f} tokens/s)") # Calculate and report statistics for single chunk single_chunk_measurement_times = single_chunk_times[num_runs // 2:] single_chunk_avg_time = sum(single_chunk_measurement_times) / len(single_chunk_measurement_times) single_chunk_min_time = min(single_chunk_measurement_times) single_chunk_max_time = max(single_chunk_measurement_times) single_chunk_tokens_per_sec = num_test_tokens / single_chunk_avg_time # Use num_test_tokens print(f"\n{LIGHT_BLUE}Single Chunk Prefill Speed Results:{RESET_COLOR}") print(f"Test Tokens: {num_test_tokens} tokens") print(f"Internal Batch Size: {batch_size} tokens") print(f"Context Size: {context_length} tokens") print(f"Average Time: {single_chunk_avg_time:.4f}s") print(f"Min Time: {single_chunk_min_time:.4f}s") print(f"Max Time: {single_chunk_max_time:.4f}s") print(f"Average Speed: {single_chunk_tokens_per_sec:.1f} tokens/second") print(f"Best Speed: {num_test_tokens / single_chunk_min_time:.1f} tokens/second") # Use num_test_tokens # Calculate overhead per chunk if len(ffn_models) > 1: chunk_overhead = (all_chunks_avg_time - single_chunk_avg_time) / (len(ffn_models) - 1) print(f"\n{LIGHT_GREEN}Chunk Overhead Analysis:{RESET_COLOR}") print(f"Single Chunk Time: {single_chunk_avg_time:.4f}s") print(f"All Chunks Time ({len(ffn_models)} chunks): {all_chunks_avg_time:.4f}s") print(f"Additional Time Per Chunk: {chunk_overhead:.4f}s") print(f"Overhead Percentage: {(all_chunks_avg_time/single_chunk_avg_time - 1)*100:.1f}%") return all_chunks_tokens_per_sec, single_chunk_tokens_per_sec def parse_args(): parser = argparse.ArgumentParser(description='Test prefill speed with CoreML LLaMA models (c) 2025 Anemll') # Add meta.yaml option parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters') # Model paths parser.add_argument('--d', '--dir', type=str, default='.', help='Directory containing model files (default: current directory)') parser.add_argument('--embed', type=str, required=False, help='Path to embeddings model (relative to --dir)') parser.add_argument('--ffn', type=str, required=False, help='Path to FFN model (can be chunked, relative to --dir)') parser.add_argument('--tokenizer', type=str, required=False, help='Path to tokenizer') # Test configuration parser.add_argument('--batch-size', type=int, help='Batch size for prefill test (default: 64)') parser.add_argument('--runs', type=int, default=20, help='Number of test runs (default: 20)') parser.add_argument('--context-length', type=int, help='Context length for the model') parser.add_argument('--no-single-chunk', action='store_true', help='Disable single chunk testing') parser.add_argument('--test-tokens', type=int, help='Number of tokens to use for the speed test (default: batch_size)') parser.add_argument('--test-full-context', action='store_true', help='Test prefill speed using the full context length (overrides --test-tokens)') args = parser.parse_args() # If meta.yaml is provided, load parameters from it if args.meta: try: with open(args.meta, 'r') as f: meta = yaml.safe_load(f) params = meta['model_info']['parameters'] # Set model directory to meta.yaml directory if not specified if not args.d or args.d == '.': args.d = str(Path(args.meta).parent) # Build model paths based on parameters prefix = params.get('model_prefix', 'llama') lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else '' lut_embeddings = f"_lut{params['lut_embeddings']}" if params['lut_embeddings'] != 'none' else '' num_chunks = int(params['num_chunks']) # Set model paths if not specified if not args.embed: args.embed = f'{prefix}_embeddings{lut_embeddings}' if not args.ffn: args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}' if not args.tokenizer: args.tokenizer = args.d # Set other parameters if not overridden by command line if args.context_length is None: args.context_length = int(params['context_length']) if args.batch_size is None: args.batch_size = int(params['batch_size']) args.num_chunks = num_chunks print(f"\nLoaded parameters from {args.meta}:") print(f" Context Length: {args.context_length}") print(f" Batch Size: {args.batch_size}") print(f" Num Chunks: {args.num_chunks}") print(f" Models Directory: {args.d}") print(f" Embeddings: {args.embed}") print(f" FFN: {args.ffn}") except Exception as e: print(f"\nError loading meta.yaml: {str(e)}") sys.exit(1) return args def main(): args = parse_args() # Use default batch size if not specified if args.batch_size is None: args.batch_size = 64 print(f"\nUsing default batch size: {args.batch_size}") # Convert directory to absolute path model_dir = Path(args.d).resolve() if not model_dir.exists(): print(f"\nError: Model directory not found: {model_dir}") return 1 print(f"\nUsing model directory: {model_dir}") try: # Update paths to be relative to model directory args.embed = str(model_dir / args.embed) args.ffn = str(model_dir / args.ffn) # Handle tokenizer path separately if args.tokenizer is None: args.tokenizer = str(model_dir) if not Path(args.tokenizer).exists(): print(f"\nError: Tokenizer directory not found: {args.tokenizer}") return 1 args.tokenizer = str(Path(args.tokenizer).resolve()) print(f"Using tokenizer path: {args.tokenizer}") # Load models and extract metadata metadata = {} embed_model, ffn_models, metadata = load_models(args, metadata) # Override context length from command line if provided if args.context_length is not None: metadata['context_length'] = args.context_length metadata['state_length'] = args.context_length print(f"\nOverriding context length from command line: {args.context_length}") # Load tokenizer tokenizer = initialize_tokenizer(args.tokenizer) if tokenizer is None: raise RuntimeError("Failed to initialize tokenizer") # Determine number of tokens for the test if args.test_full_context: num_test_tokens = metadata['context_length'] print(f"\nTesting with full context length: {num_test_tokens} tokens") elif args.test_tokens is not None: num_test_tokens = args.test_tokens print(f"\nTesting with specified tokens: {num_test_tokens} tokens") else: num_test_tokens = args.batch_size # Default to batch size print(f"\nTesting with default tokens (batch size): {num_test_tokens} tokens") # Ensure test tokens do not exceed context length if num_test_tokens > metadata['context_length']: print(f"\nWarning: Requested test tokens ({num_test_tokens}) exceed context length ({metadata['context_length']}).") print(f"Clamping test tokens to context length.") num_test_tokens = metadata['context_length'] # Run prefill speed test test_prefill_speed( embed_model=embed_model, ffn_models=ffn_models, tokenizer=tokenizer, batch_size=args.batch_size, # Pass original batch_size for run_prefill internal logic context_length=metadata['context_length'], num_test_tokens=num_test_tokens, # Pass the number of tokens to actually test num_runs=args.runs, test_single_chunk=not args.no_single_chunk ) except Exception as e: print(f"\nError: {str(e)}") import traceback traceback.print_exc() return 1 return 0 if __name__ == "__main__": exit(main())