#!/usr/bin/env python
"""
Script to distill Alibaba-NLP/gte-Qwen2-7B-instruct using Model2Vec.

This script performs the following operations:
1. Downloads the Alibaba-NLP/gte-Qwen2-7B-instruct model
2. Distills it using Model2Vec to create a smaller, faster static model
3. Saves the distilled model for further use
"""

import logging
import shutil
import time
from pathlib import Path

from model2vec.distill import distill

# =============================================================================
# CONFIGURATION CONSTANTS
# =============================================================================

# Model Configuration
MODEL_NAME = "Alibaba-NLP/gte-Qwen2-7B-instruct"  # Model name or path for the source model
OUTPUT_DIR = "."  # Directory to save the distilled model (current directory)
PCA_DIMS = 256  # Dimensions for PCA reduction (smaller = faster but less accurate)

# Hub Configuration
SAVE_TO_HUB = False  # Whether to push the model to the HuggingFace Hub
HUB_MODEL_ID = None  # Model ID for the HuggingFace Hub (required if SAVE_TO_HUB is True)

# Generation Configuration
SKIP_README = True  # Skip generating the README file

# =============================================================================

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def main() -> None:
    """Run the distillation process for Alibaba-NLP/gte-Qwen2-7B-instruct."""
    # Create the output directory if it doesn't exist
    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Starting distillation of {MODEL_NAME}")
    logger.info(f"Distilled model will be saved to {output_dir}")
    logger.info(f"Using PCA dimensions: {PCA_DIMS}")
    logger.info(f"Skipping README generation: {SKIP_README}")

    # Record start time for benchmarking
    start_time = time.time()

    # Run the distillation
    try:
        logger.info("Starting Model2Vec distillation...")
        m2v_model = distill(
            model_name=MODEL_NAME,
            pca_dims=PCA_DIMS,
        )

        distill_time = time.time() - start_time
        logger.info(f"Distillation completed in {distill_time:.2f} seconds")

        # Save the distilled model
        m2v_model.save_pretrained(OUTPUT_DIR)
        logger.info(f"Model saved to {OUTPUT_DIR}")

        # Remove README.md if it was created and we want to skip it
        if SKIP_README and (output_dir / "README.md").exists():
            (output_dir / "README.md").unlink()
            logger.info("Removed auto-generated README.md")

        # Report the on-disk size of the saved model files (excluding README.md)
        model_size_mb = sum(
            f.stat().st_size for f in output_dir.glob("**/*") if f.is_file() and f.name != "README.md"
        ) / (1024 * 1024)
        logger.info(f"Distilled model size: {model_size_mb:.2f} MB")

        # Push to the Hub if requested
        if SAVE_TO_HUB:
            if HUB_MODEL_ID:
                logger.info(f"Pushing model to HuggingFace Hub as {HUB_MODEL_ID}")

                # Temporarily move aside any existing README before pushing,
                # so the push can generate its own model card
                readme_path = output_dir / "README.md"
                had_readme = readme_path.exists()
                if SKIP_README and had_readme:
                    # Back up the README
                    shutil.move(readme_path, output_dir / "README.md.bak")

                # Push to the Hub
                m2v_model.push_to_hub(HUB_MODEL_ID)

                # Restore the local state
                if SKIP_README:
                    if had_readme:
                        # Restore the backup
                        shutil.move(output_dir / "README.md.bak", readme_path)
                    elif (output_dir / "README.md").exists():
                        # Remove the README created during push_to_hub
                        (output_dir / "README.md").unlink()
            else:
                logger.error("HUB_MODEL_ID must be specified when SAVE_TO_HUB is True")

        logger.info("Distillation process completed successfully!")

    except Exception:
        logger.exception("Error during distillation")
        raise


if __name__ == "__main__":
    main()
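
# Example: loading and using the distilled model afterwards. This is a minimal
# sketch, not part of the distillation flow; it assumes the model was saved to
# the current directory (OUTPUT_DIR above) and that the installed model2vec
# release exposes StaticModel.from_pretrained with local-path support — verify
# against your installed version before relying on it.
#
#   from model2vec import StaticModel
#
#   model = StaticModel.from_pretrained(".")  # directory where this script saved the model
#   embeddings = model.encode(["An example sentence to embed."])
#   print(embeddings.shape)  # (1, PCA_DIMS), i.e. (1, 256) with the configuration above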