# model_trace/src/evaluation/model_trace_eval.py
# Author: Ahmed Ahmed  (submission id: 4864926)
"""
Model tracing evaluation for computing p-values from neuron matching statistics.
This module runs the model-tracing comparison using the main.py script from model-tracing
to determine structural similarity via p-value analysis.
"""
import os
import sys
import subprocess
import tempfile
import pickle
import statistics
# Probe for the bundled model-tracing checkout: analysis can only run when its
# main.py entry point is present two directories above this module.
model_tracing_path = os.path.join(os.path.dirname(__file__), '../../model-tracing')
_dir_exists = os.path.exists(model_tracing_path)
_main_py_exists = os.path.exists(os.path.join(model_tracing_path, 'main.py'))
MODEL_TRACING_AVAILABLE = _dir_exists and _main_py_exists

# Log the probe outcome up front so misconfiguration is obvious at import time.
sys.stderr.write("πŸ”§ CHECKING MODEL TRACING AVAILABILITY...\n")
sys.stderr.write(f" - Model tracing path: {model_tracing_path}\n")
sys.stderr.write(f" - Path exists: {_dir_exists}\n")
sys.stderr.write(f" - main.py exists: {_main_py_exists}\n")
sys.stderr.write(f"🎯 Final MODEL_TRACING_AVAILABLE = {MODEL_TRACING_AVAILABLE}\n")
sys.stderr.flush()
def _aggregate_p_values(stat):
    """Collapse a test statistic (scalar or per-layer list) into one p-value.

    For a list, takes the median of entries that are numeric and within
    [0, 1]; for a bare number, returns it as float. Falls back to 1.0
    (the "no evidence" p-value) when nothing usable is present.
    """
    if isinstance(stat, list):
        valid_p_values = [
            p for p in stat
            if p is not None and isinstance(p, (int, float)) and 0 <= p <= 1
        ]
        sys.stderr.write(f"πŸ“Š Valid p-values: {len(valid_p_values)}/{len(stat)}\n")
        if valid_p_values:
            # Median is robust to a few outlier layers.
            return statistics.median(valid_p_values)
        sys.stderr.write("⚠️ No valid p-values found, using default\n")
        return 1.0
    if isinstance(stat, (int, float)):
        return float(stat)
    sys.stderr.write(f"⚠️ Unexpected test stat type: {type(stat)}, using default\n")
    return 1.0


def run_model_trace_analysis(ft_model_name, revision="main", precision="float16"):
    """
    Run model tracing analysis using the main.py script from model-tracing directory.

    Runs the command:
        python main.py --base_model_id meta-llama/Llama-2-7b-hf \
            --ft_model_id <ft_model_name> --stat match --align --save <tmpfile>

    Args:
        ft_model_name: HuggingFace model identifier for the fine-tuned model.
        revision: Model revision/commit hash (logged only; main.py has no
            --revision flag).
        precision: Model precision (float16, bfloat16). Logged only; not
            forwarded to main.py.

    Returns:
        tuple: (success: bool, result: float or error_message)
            If success, result is the aggregate p-value from the aligned test
            stat (median across layers).
            If failure, result is an error message string.
    """
    if not MODEL_TRACING_AVAILABLE:
        return False, "Model tracing main.py script not available"

    tmp_results_path = None
    try:
        sys.stderr.write("\n=== RUNNING MODEL TRACE ANALYSIS VIA SUBPROCESS ===\n")
        sys.stderr.write("Base model: meta-llama/Llama-2-7b-hf\n")
        sys.stderr.write(f"Fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f"Revision: {revision}\n")
        sys.stderr.write(f"Precision: {precision}\n")
        sys.stderr.flush()

        # Temp file receives the pickled results from main.py via --save.
        with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file:
            tmp_results_path = tmp_file.name
        sys.stderr.write(f"πŸ“ Temporary results file: {tmp_results_path}\n")
        sys.stderr.flush()

        base_model_id = "meta-llama/Llama-2-7b-hf"
        # BUGFIX: --align was missing from the command even though both the
        # documented command and the result parsing below rely on the
        # "aligned test stat" key that only --align produces.
        # sys.executable guarantees the same interpreter as this process.
        cmd = [
            sys.executable, "main.py",
            "--base_model_id", base_model_id,
            "--ft_model_id", ft_model_name,
            "--stat", "match",
            "--align",
            "--save", tmp_results_path,
        ]

        if revision and revision != "main":
            # main.py has no revision flag; log for traceability only.
            sys.stderr.write(f"⚠️ Note: Revision '{revision}' specified but main.py doesn't support --revision flag\n")
            sys.stderr.flush()

        sys.stderr.write(f"πŸš€ Running command: {' '.join(cmd)}\n")
        sys.stderr.write(f"πŸ“‚ Subprocess working directory: {model_tracing_path}\n")
        sys.stderr.flush()

        # cwd= keeps this process's working directory untouched (the previous
        # os.chdir approach mutated process-global state).
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=3600,  # 1 hour timeout
            cwd=model_tracing_path,
        )

        sys.stderr.write(f"πŸ“Š Subprocess completed with return code: {result.returncode}\n")
        if result.stdout:
            sys.stderr.write(f"πŸ“ STDOUT from model tracing:\n{result.stdout}\n")
        if result.stderr:
            sys.stderr.write(f"⚠️ STDERR from model tracing:\n{result.stderr}\n")
        sys.stderr.flush()

        if result.returncode != 0:
            error_msg = f"Model tracing script failed with return code {result.returncode}"
            if result.stderr:
                error_msg += f"\nSTDERR: {result.stderr}"
            return False, error_msg

        # Load the pickled results produced by --save.
        try:
            sys.stderr.write(f"πŸ“– Loading results from: {tmp_results_path}\n")
            sys.stderr.flush()
            with open(tmp_results_path, 'rb') as f:
                results = pickle.load(f)
            sys.stderr.write("βœ… Results loaded successfully\n")
            sys.stderr.write(f"πŸ“‹ Available result keys: {list(results.keys())}\n")
            sys.stderr.flush()
        except Exception as e:
            sys.stderr.write(f"❌ Failed to load results: {e}\n")
            sys.stderr.flush()
            return False, f"Failed to load results: {e}"

        # Prefer the aligned statistic (what --align produces); fall back to
        # the non-aligned one; fail if neither is present.
        if "aligned test stat" in results:
            aligned_stat = results["aligned test stat"]
            sys.stderr.write(f"πŸ“Š Aligned test stat: {aligned_stat}\n")
            sys.stderr.write(f"πŸ“Š Type: {type(aligned_stat)}\n")
            aggregate_p_value = _aggregate_p_values(aligned_stat)
        elif "non-aligned test stat" in results:
            sys.stderr.write("⚠️ No 'aligned test stat' found in results, checking non-aligned\n")
            non_aligned_stat = results["non-aligned test stat"]
            sys.stderr.write(f"πŸ“Š Using non-aligned test stat: {non_aligned_stat}\n")
            aggregate_p_value = _aggregate_p_values(non_aligned_stat)
        else:
            sys.stderr.write("❌ No test stat found in results\n")
            return False, "No test statistic found in results"
        sys.stderr.flush()

        sys.stderr.write(f"βœ… Final aggregate p-value: {aggregate_p_value}\n")
        sys.stderr.write("=== MODEL TRACE ANALYSIS COMPLETED ===\n")
        sys.stderr.flush()
        return True, aggregate_p_value

    except subprocess.TimeoutExpired:
        sys.stderr.write("❌ Model tracing analysis timed out after 1 hour\n")
        sys.stderr.flush()
        return False, "Analysis timed out"
    except Exception as e:
        error_msg = str(e)
        sys.stderr.write(f"πŸ’₯ Error in model trace analysis: {error_msg}\n")
        import traceback
        sys.stderr.write(f"Traceback: {traceback.format_exc()}\n")
        sys.stderr.flush()
        return False, error_msg
    finally:
        # Clean up on every exit path (the original leaked the temp file when
        # the subprocess failed before the parsing step was reached).
        if tmp_results_path:
            try:
                os.unlink(tmp_results_path)
                sys.stderr.write(f"πŸ—‘οΈ Cleaned up temporary file: {tmp_results_path}\n")
            except OSError:
                pass
def compute_model_trace_p_value(model_name, revision="main", precision="float16"):
    """
    Wrapper that computes the model-trace p-value for a single model.

    Args:
        model_name: HuggingFace model identifier.
        revision: Model revision.
        precision: Model precision.

    Returns:
        float or None: the p-value when the analysis succeeds, None on any
        failure (including model tracing being unavailable).
    """
    log = sys.stderr.write
    separator = '=' * 60

    # Announce the call with full parameters so failures are easy to trace.
    log(f"\n{separator}\n")
    log("COMPUTE_MODEL_TRACE_P_VALUE CALLED\n")
    log(f"Model: {model_name}\n")
    log(f"Revision: {revision}\n")
    log(f"Precision: {precision}\n")
    log(f"Model tracing available: {MODEL_TRACING_AVAILABLE}\n")
    log(f"{separator}\n")
    sys.stderr.flush()

    # Guard clause: nothing to do without the model-tracing checkout.
    if not MODEL_TRACING_AVAILABLE:
        log("❌ MODEL TRACING NOT AVAILABLE - returning None\n")
        sys.stderr.flush()
        return None

    try:
        log("πŸš€ Starting model trace analysis...\n")
        sys.stderr.flush()
        success, result = run_model_trace_analysis(model_name, revision, precision)
        log(f"πŸ“Š Analysis completed - Success: {success}, Result: {result}\n")
        sys.stderr.flush()

        if not success:
            log(f"❌ FAILED: {result}\n")
            log("πŸ”„ Returning None as fallback\n")
            sys.stderr.flush()
            return None

        log(f"βœ… SUCCESS: Returning p-value {result}\n")
        sys.stderr.flush()
        return result
    except Exception as e:
        # Last-resort guard: any unexpected error degrades to None.
        log(f"πŸ’₯ CRITICAL ERROR in compute_model_trace_p_value for {model_name}:\n")
        log(f"Exception: {e}\n")
        import traceback
        log(f"Full traceback:\n{traceback.format_exc()}\n")
        log("πŸ”„ Returning None as fallback\n")
        sys.stderr.flush()
        return None