Spaces:
Runtime error
Runtime error
""" | |
Model tracing evaluation for computing p-values from neuron matching statistics.
This module runs the model-tracing comparison using the main.py script from model-tracing
to determine structural similarity via p-value analysis.
""" | |
import os | |
import sys | |
import subprocess | |
import tempfile | |
import pickle | |
import statistics | |
# Locate the model-tracing checkout that is expected two directories above this
# file. The path is canonicalised once (symlinks and '..' resolved) so it stays
# valid even if the process changes its working directory later.
model_tracing_path = os.path.realpath(
    os.path.join(os.path.dirname(__file__), '../../model-tracing')
)

# Probe the filesystem exactly once so the values logged below are guaranteed
# to be the ones the availability flag was derived from.
_model_tracing_dir_exists = os.path.exists(model_tracing_path)
# isfile (not exists): a directory that happens to be called main.py cannot be run.
_model_tracing_main_exists = os.path.isfile(os.path.join(model_tracing_path, 'main.py'))

# True only when both the directory and its main.py entry point are present.
MODEL_TRACING_AVAILABLE = _model_tracing_dir_exists and _model_tracing_main_exists

# Import-time diagnostics, written to stderr.
sys.stderr.write("π§ CHECKING MODEL TRACING AVAILABILITY...\n")
sys.stderr.write(f" - Model tracing path: {model_tracing_path}\n")
sys.stderr.write(f" - Path exists: {_model_tracing_dir_exists}\n")
sys.stderr.write(f" - main.py exists: {_model_tracing_main_exists}\n")
sys.stderr.write(f"π― Final MODEL_TRACING_AVAILABLE = {MODEL_TRACING_AVAILABLE}\n")
sys.stderr.flush()
def _aggregate_p_value(stat):
    """Collapse a test statistic into a single representative p-value.

    Args:
        stat: Either a list/tuple of per-layer p-values or a single real number.

    Returns:
        float: The median of all entries that are real numbers within [0, 1].
        Falls back to 1.0 when there is no such entry or when ``stat`` has an
        unsupported type. A scalar is validated exactly like a one-element list.
    """
    # Local import keeps the helper self-contained; numbers.Real also matches
    # real-number types that do not subclass int/float but register with the ABC.
    import numbers

    if isinstance(stat, (list, tuple)):
        candidates = list(stat)
        sys.stderr.write(f"π List of {len(candidates)} p-values: {candidates}\n")
    elif isinstance(stat, numbers.Real):
        candidates = [stat]
    else:
        sys.stderr.write(f"β οΈ Unexpected test stat type: {type(stat)}, using default\n")
        return 1.0

    # NaN fails the range comparison, so it is filtered out here as well.
    valid_p_values = [
        float(p) for p in candidates
        if isinstance(p, numbers.Real) and 0 <= p <= 1
    ]
    sys.stderr.write(f"π Valid p-values: {len(valid_p_values)}/{len(candidates)}\n")

    if not valid_p_values:
        sys.stderr.write("β οΈ No valid p-values found, using default\n")
        return 1.0

    aggregate = statistics.median(valid_p_values)
    sys.stderr.write(f"π Using median p-value: {aggregate}\n")
    return aggregate


def run_model_trace_analysis(ft_model_name, revision="main", precision="float16"):
    """
    Run model tracing analysis using the main.py script from model-tracing directory.

    Runs, with the model-tracing directory as the child's working directory:
        <current interpreter> main.py --base_model_id meta-llama/Llama-2-7b-hf
            --ft_model_id <ft_model_name> --stat match --save <temp .pkl file>

    NOTE(review): no --align flag is passed, although the result parsing prefers
    the "aligned test stat" key. If main.py only emits that key when --align is
    given, the "non-aligned test stat" fallback is what gets used - confirm
    against main.py before adding the flag.

    Args:
        ft_model_name: HuggingFace model identifier for the fine-tuned model
        revision: Model revision/commit hash. Only logged; it is not forwarded
            to main.py.
        precision: Model precision (float16, bfloat16). Only logged; it is not
            forwarded to main.py.

    Returns:
        tuple: (success: bool, result: float or error_message)
        If success, result is the aggregate (median) p-value taken from the
        aligned test stat when present, otherwise from the non-aligned one.
        If failure, result is an error message. No exception derived from
        Exception escapes this function.
    """
    if not MODEL_TRACING_AVAILABLE:
        return False, "Model tracing main.py script not available"

    base_model_id = "meta-llama/Llama-2-7b-hf"
    tmp_results_path = None
    try:
        sys.stderr.write("\n=== RUNNING MODEL TRACE ANALYSIS VIA SUBPROCESS ===\n")
        sys.stderr.write(f"Base model: {base_model_id}\n")
        sys.stderr.write(f"Fine-tuned model: {ft_model_name}\n")
        sys.stderr.write(f"Revision: {revision}\n")
        sys.stderr.write(f"Precision: {precision}\n")
        sys.stderr.flush()

        # Reserve a path for the child to write its results to. delete=False
        # because the file must outlive this with-block; it is removed in the
        # outer finally on every exit path.
        with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp_file:
            tmp_results_path = tmp_file.name
        sys.stderr.write(f"π Temporary results file: {tmp_results_path}\n")
        sys.stderr.flush()

        cmd = [
            # Same interpreter (and therefore same installed packages) as the
            # caller, rather than whatever "python" resolves to on PATH.
            sys.executable or "python", "main.py",
            "--base_model_id", base_model_id,
            "--ft_model_id", ft_model_name,
            "--stat", "match",
            "--save", tmp_results_path,
        ]

        if revision and revision != "main":
            # main.py doesn't seem to have a revision flag, so it is only logged.
            sys.stderr.write(f"β οΈ Note: Revision '{revision}' specified but main.py doesn't support --revision flag\n")
            sys.stderr.flush()

        sys.stderr.write(f"π Running command: {' '.join(cmd)}\n")
        sys.stderr.write(f"π Working directory: {model_tracing_path}\n")
        sys.stderr.flush()

        # cwd= scopes the directory change to the child process; os.chdir would
        # change it for the whole parent process, including other threads.
        result = subprocess.run(
            cmd,
            cwd=model_tracing_path,
            capture_output=True,
            text=True,
            timeout=3600,  # 1 hour timeout
        )

        sys.stderr.write(f"π Subprocess completed with return code: {result.returncode}\n")
        if result.stdout:
            sys.stderr.write(f"π STDOUT from model tracing:\n{result.stdout}\n")
        if result.stderr:
            sys.stderr.write(f"β οΈ STDERR from model tracing:\n{result.stderr}\n")
        sys.stderr.flush()

        if result.returncode != 0:
            error_msg = f"Model tracing script failed with return code {result.returncode}"
            if result.stderr:
                error_msg += f"\nSTDERR: {result.stderr}"
            return False, error_msg

        # Load and parse the results
        try:
            sys.stderr.write(f"π Loading results from: {tmp_results_path}\n")
            sys.stderr.flush()

            # SECURITY: unpickling executes arbitrary code. This is acceptable
            # only because the file was just written by our own subprocess; the
            # model-tracing checkout must be treated as trusted code.
            with open(tmp_results_path, 'rb') as f:
                results = pickle.load(f)

            sys.stderr.write("β Results loaded successfully\n")
            sys.stderr.write(f"π Available result keys: {list(results.keys())}\n")
            sys.stderr.flush()

            if "aligned test stat" in results:
                aligned_stat = results["aligned test stat"]
                sys.stderr.write(f"π Aligned test stat: {aligned_stat}\n")
                sys.stderr.write(f"π Type: {type(aligned_stat)}\n")
                aggregate_p_value = _aggregate_p_value(aligned_stat)
            elif "non-aligned test stat" in results:
                sys.stderr.write("β οΈ No 'aligned test stat' found in results, checking non-aligned\n")
                non_aligned_stat = results["non-aligned test stat"]
                sys.stderr.write(f"π Using non-aligned test stat: {non_aligned_stat}\n")
                aggregate_p_value = _aggregate_p_value(non_aligned_stat)
            else:
                sys.stderr.write("β οΈ No 'aligned test stat' found in results, checking non-aligned\n")
                sys.stderr.write("β No test stat found in results\n")
                sys.stderr.flush()
                return False, "No test statistic found in results"
            sys.stderr.flush()
        except Exception as e:
            # Covers unreadable/empty pickle files as well as unexpected result
            # shapes (e.g. the payload is not a mapping).
            sys.stderr.write(f"β Failed to load results: {e}\n")
            sys.stderr.flush()
            return False, f"Failed to load results: {e}"

        sys.stderr.write(f"β Final aggregate p-value: {aggregate_p_value}\n")
        sys.stderr.write("=== MODEL TRACE ANALYSIS COMPLETED ===\n")
        sys.stderr.flush()
        return True, aggregate_p_value

    except subprocess.TimeoutExpired:
        sys.stderr.write("β Model tracing analysis timed out after 1 hour\n")
        sys.stderr.flush()
        return False, "Analysis timed out"

    except Exception as e:
        error_msg = str(e)
        sys.stderr.write(f"π₯ Error in model trace analysis: {error_msg}\n")
        import traceback
        sys.stderr.write(f"Traceback: {traceback.format_exc()}\n")
        sys.stderr.flush()
        return False, error_msg

    finally:
        # Remove the temporary results file on every exit path (success, failed
        # subprocess, timeout, unexpected error) so repeated runs do not leak files.
        if tmp_results_path is not None:
            try:
                os.unlink(tmp_results_path)
                sys.stderr.write(f"ποΈ Cleaned up temporary file: {tmp_results_path}\n")
            except OSError:
                # Best effort: the file may already be gone.
                pass
            sys.stderr.flush()
def compute_model_trace_p_value(model_name, revision="main", precision="float16"):
    """
    Compute the model-trace p-value for one model, never raising on failure.

    Delegates to run_model_trace_analysis and converts its (success, result)
    pair into a plain value, logging every step to stderr.

    Args:
        model_name: HuggingFace model identifier
        revision: Model revision
        precision: Model precision

    Returns:
        float or None: the p-value when the analysis succeeded; None when model
        tracing is unavailable, the analysis reported failure, or an exception
        was raised while running it.
    """
    def _log(message):
        # Single place that emits a diagnostic line to stderr.
        sys.stderr.write(message)

    separator = '=' * 60

    _log(f"\n{separator}\n")
    _log("COMPUTE_MODEL_TRACE_P_VALUE CALLED\n")
    _log(f"Model: {model_name}\n")
    _log(f"Revision: {revision}\n")
    _log(f"Precision: {precision}\n")
    _log(f"Model tracing available: {MODEL_TRACING_AVAILABLE}\n")
    _log(f"{separator}\n")
    sys.stderr.flush()

    # Guard: nothing to run without the model-tracing checkout.
    if not MODEL_TRACING_AVAILABLE:
        _log("β MODEL TRACING NOT AVAILABLE - returning None\n")
        sys.stderr.flush()
        return None

    try:
        _log("π Starting model trace analysis...\n")
        sys.stderr.flush()

        outcome = run_model_trace_analysis(model_name, revision, precision)
        succeeded, payload = outcome

        _log(f"π Analysis completed - Success: {succeeded}, Result: {payload}\n")
        sys.stderr.flush()

        if not succeeded:
            # On failure the payload is the error message.
            _log(f"β FAILED: {payload}\n")
            _log("π Returning None as fallback\n")
            sys.stderr.flush()
            return None

        _log(f"β SUCCESS: Returning p-value {payload}\n")
        sys.stderr.flush()
        return payload

    except Exception as exc:
        # Last-resort boundary: callers get None rather than an exception.
        _log(f"π₯ CRITICAL ERROR in compute_model_trace_p_value for {model_name}:\n")
        _log(f"Exception: {exc}\n")
        import traceback
        _log(f"Full traceback:\n{traceback.format_exc()}\n")
        _log("π Returning None as fallback\n")
        sys.stderr.flush()
        return None