WCNegentropy commited on
Commit
8414e94
·
verified ·
1 Parent(s): e2ef423

🚀 Final optimization: Update types.py with production-ready enhancements

Browse files
Files changed (1) hide show
  1. bit_transformer/types.py +117 -0
bit_transformer/types.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Type definitions and type aliases for BitTransformerLM.
3
+
4
+ Provides standardized type hints and common type aliases used throughout the codebase.
5
+ """
6
+
7
+ from __future__ import annotations
8
+ from typing import Union, List, Dict, Tuple, Optional, Any, Callable, Protocol
9
+ from pathlib import Path
10
+ import torch
11
+ import numpy as np
12
+
13
+ # Common tensor types
14
+ TensorLike = Union[torch.Tensor, np.ndarray, List[float], List[int]]
15
+ DeviceType = Union[str, torch.device]
16
+ DtypeType = Union[torch.dtype, type, str]
17
+
18
+ # Bit sequence types
19
+ BitSequence = List[int] # List of 0s and 1s
20
+ BitTensor = torch.Tensor # Tensor containing 0s and 1s
21
+ BitBatch = Union[List[BitSequence], torch.Tensor]
22
+
23
+ # Model types
24
+ ModelOutput = Union[torch.Tensor, Tuple[torch.Tensor, ...]]
25
+ TelemetryDict = Dict[str, Union[float, List[float], torch.Tensor]]
26
+ SafetyMetrics = Dict[str, float]
27
+
28
+ # File and path types
29
+ PathLike = Union[str, Path]
30
+ OptionalPath = Optional[PathLike]
31
+
32
+ # Training types
33
+ LossValue = Union[float, torch.Tensor]
34
+ OptimizerState = Dict[str, Any]
35
+ SchedulerState = Dict[str, Any]
36
+
37
+ # Configuration types
38
+ ModelConfig = Dict[str, Any]
39
+ TrainingConfig = Dict[str, Any]
40
+ DatasetConfig = Dict[str, Any]
41
+
42
+ # HuggingFace types
43
+ HFRepoId = str
44
+ HFToken = Optional[str]
45
+
46
+ # Function type protocols
47
+ class ModelForward(Protocol):
48
+ """Protocol for model forward functions."""
49
+ def __call__(self,
50
+ inputs: BitTensor,
51
+ attention_mask: Optional[torch.Tensor] = None,
52
+ **kwargs) -> ModelOutput: ...
53
+
54
+ class LossFunction(Protocol):
55
+ """Protocol for loss functions."""
56
+ def __call__(self,
57
+ predictions: torch.Tensor,
58
+ targets: torch.Tensor) -> LossValue: ...
59
+
60
+ class MetricFunction(Protocol):
61
+ """Protocol for metric computation functions."""
62
+ def __call__(self,
63
+ predictions: torch.Tensor,
64
+ targets: torch.Tensor) -> float: ...
65
+
66
+ # Compression types
67
+ CompressedData = torch.Tensor
68
+ CompressionRatio = float
69
+
70
+ # Safety and telemetry types
71
+ NegentropyScore = float # K metric: 0 (random) to 1 (ordered)
72
+ ComplexityScore = float # C metric: LZ complexity proxy
73
+ SymbiosisScore = float # S metric: KL divergence alignment
74
+
75
+ SafetyThresholds = Dict[str, float]
76
+ TelemetryCallback = Callable[[TelemetryDict], None]
77
+
78
+ # Distributed training types
79
+ WorldSize = int
80
+ ProcessRank = int
81
+ DistributedConfig = Dict[str, Any]
82
+
83
+ # Quantization types
84
+ QuantizationConfig = Dict[str, Any]
85
+ QuantizedModel = torch.nn.Module
86
+
87
+ # Common type aliases for cleaner signatures
88
+ BatchSize = int
89
+ SequenceLength = int
90
+ VocabSize = int
91
+ HiddenSize = int
92
+ NumHeads = int
93
+ NumLayers = int
94
+
95
+ # Attention types
96
+ AttentionWeights = torch.Tensor
97
+ AttentionMask = Optional[torch.Tensor]
98
+ ChunkSize = Optional[int]
99
+
100
+ # Generation types
101
+ GenerationConfig = Dict[str, Any]
102
+ GeneratedSequence = BitSequence
103
+ GenerationCallback = Callable[[GeneratedSequence], None]
104
+
105
+ # Diffusion types
106
+ NoiseSchedule = str # 'linear', 'cosine', 'exponential'
107
+ DiffusionSteps = int
108
+ DiffusionConfig = Dict[str, Any]
109
+
110
+ # Error handling types
111
+ ErrorHandler = Callable[[Exception], None]
112
+ RecoveryStrategy = Callable[[], Any]
113
+
114
+ # Logging types
115
+ LogLevel = str # 'DEBUG', 'INFO', 'WARNING', 'ERROR'
116
+ LogMessage = str
117
+ Logger = Any # To avoid circular import with logging module