"""Command-line interface entry points for BitTransformerLM.""" import sys import logging from pathlib import Path from typing import Optional import torch from .cli_standards import create_training_parser, create_inference_parser, BitTransformerCLI from .config import ( ExperimentConfig, ModelConfig, TrainingConfig, SafetyConfig, DataConfig, get_small_config, get_medium_config, get_large_config, ) from .model import BitTransformerLM, diffusion_inference from .training import train_loop from .bit_io import text_to_bits, bits_to_text, infer_text from .utils import save_model, load_model from .dashboard_app import run_dashboard def setup_logging(level: str = "INFO") -> None: """Setup logging configuration.""" logging.basicConfig( level=getattr(logging, level.upper()), format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", handlers=[ logging.StreamHandler(sys.stdout), ], ) def train_cli() -> None: """CLI entry point for training BitTransformerLM models.""" parser = create_training_parser() args = parser.parse_args() setup_logging(args.log_level) logger = logging.getLogger(__name__) # Get preset configuration if specified if args.model_size == "small": config = get_small_config() elif args.model_size == "medium": config = get_medium_config() elif args.model_size == "large": config = get_large_config() else: config = ExperimentConfig() # Override with command line arguments config.model.d_model = args.d_model config.model.nhead = args.num_heads config.model.num_layers = args.num_layers config.model.max_seq_len = args.max_seq_len config.training.epochs = args.epochs config.training.batch_size = args.batch_size config.training.learning_rate = args.learning_rate config.training.weight_decay = args.weight_decay config.training.gradient_clip_val = args.grad_clip config.training.warmup_steps = args.warmup_steps config.training.amp = args.use_amp config.training.compile_model = args.compile_model config.safety.k_threshold = args.min_negentropy config.safety.c_threshold = args.max_complexity config.safety.s_threshold = args.min_symbiosis config.safety.enable_safety = args.enable_safety_gates config.data.dataset_path = Path(args.input_path) if args.input_path else None config.data.max_sequence_length = args.seq_length config.data.num_workers = args.num_workers config.output_dir = Path(args.output_path) config.seed = args.seed # Set device if torch.cuda.is_available(): config.device = "cuda" else: config.device = "cpu" logger.info(f"Starting training with config: {config.experiment_name}") logger.info(f"Model: {config.model.d_model}d, {config.model.num_layers}L, {config.model.nhead}H") logger.info(f"Device: {config.device}") # Create model model = BitTransformerLM(**config.model.to_dict()) model = model.to(config.device) # Create synthetic dataset for demonstration logger.info("Creating synthetic training data...") torch.manual_seed(config.seed) data = torch.randint(0, 2, (args.dataset_size, config.data.max_sequence_length)) # Train model logger.info("Starting training...") try: train_loop( model, data, epochs=config.training.epochs, batch_size=config.training.batch_size, amp=config.training.amp, compile_model=config.training.compile_model, log=True, ) # Save model save_path = config.output_dir / "model_final.pt" save_model(model, save_path) logger.info(f"Model saved to {save_path}") except Exception as e: logger.error(f"Training failed: {e}") sys.exit(1) def infer_cli() -> None: """CLI entry point for BitTransformerLM inference.""" parser = create_inference_parser() parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation") parser.add_argument("--max-tokens", type=int, default=50, help="Maximum tokens to generate") parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature") parser.add_argument("--use-diffusion", action="store_true", help="Use diffusion mode") args = parser.parse_args() setup_logging(args.log_level) logger = logging.getLogger(__name__) # Load model if not Path(args.weights_path).exists(): logger.error(f"Model weights not found at {args.weights_path}") sys.exit(1) logger.info(f"Loading model from {args.weights_path}") model = load_model(args.weights_path) model.eval() # Set device device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) logger.info(f"Model loaded on {device}") logger.info(f"Prompt: {args.prompt}") try: if args.use_diffusion: # Diffusion inference logger.info("Using diffusion inference mode") prompt_bits = text_to_bits(args.prompt) length = len(prompt_bits) + args.max_tokens * 9 # Approximate generated_bits = diffusion_inference( model, length=length, steps=args.diffusion_steps, schedule=args.noise_schedule, ) result = bits_to_text(generated_bits[0].tolist()) else: # Standard autoregressive inference with safety if args.enable_safety_gates: result = infer_text( model, args.prompt, c_floor=args.max_complexity, s_floor=args.min_symbiosis, ) else: # Simple generation without safety gates from .bit_io import sample_text result = sample_text( model, args.prompt, max_new_tokens=args.max_tokens, temperature=args.temperature, ) print(f"\nGenerated text:\n{result}") except Exception as e: logger.error(f"Inference failed: {e}") sys.exit(1) def dashboard_cli() -> None: """CLI entry point for BitTransformerLM dashboard.""" parser = BitTransformerCLI.create_standard_parser( "BitTransformerLM Dashboard", ["io"] ) parser.add_argument("--host", type=str, default="127.0.0.1", help="Dashboard host") parser.add_argument("--port", type=int, default=7860, help="Dashboard port") parser.add_argument("--share", action="store_true", help="Create public link") args = parser.parse_args() setup_logging(args.log_level) logger = logging.getLogger(__name__) logger.info(f"Starting BitTransformerLM dashboard on {args.host}:{args.port}") try: run_dashboard( host=args.host, port=args.port, share=args.share, ) except Exception as e: logger.error(f"Dashboard failed to start: {e}") sys.exit(1) if __name__ == "__main__": # Simple dispatcher based on script name import os script_name = os.path.basename(sys.argv[0]) if "train" in script_name: train_cli() elif "infer" in script_name: infer_cli() elif "dashboard" in script_name: dashboard_cli() else: print("Available commands:") print(" bit-transformer-train - Train a BitTransformerLM model") print(" bit-transformer-infer - Run inference with a trained model") print(" bit-transformer-dashboard - Launch interactive dashboard") sys.exit(1)