Sarthak
commited on
Commit
Β·
0dbb356
1
Parent(s):
7837959
chore: update README and REPORT with performance insights and dataset changes
Browse filesThis commit adds a warning in the README regarding the performance degradation observed with C4 fine-tuning, recommending the use of basic distillation for optimal results. The REPORT has been updated to reflect revised performance metrics for the fine-tuned model, adjustments in average performance statistics, and the inclusion of new radar charts for model comparisons. Additionally, the dataset configuration has been modified to utilize the C4 dataset for tokenlearn featurization.
- NOTES.md +187 -0
- README.md +3 -0
- REPORT.md +88 -16
- analysis_charts/batch_size_scaling.png +2 -2
- analysis_charts/benchmark_performance.png +2 -2
- analysis_charts/efficiency_analysis.png +2 -2
- analysis_charts/language_heatmap.png +2 -2
- analysis_charts/memory_scaling.png +2 -2
- analysis_charts/model_comparison.png +2 -2
- analysis_charts/model_specifications.png +2 -2
- analysis_charts/peer_comparison.png +2 -2
- analysis_charts/radar_code_model2vec_all_mpnet_base_v2_fine_tuned.png +2 -2
- src/distiller/__main__.py +0 -4
- src/distiller/config.py +5 -7
- src/distiller/distill.py +23 -220
NOTES.md
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Research Notes: Performance Analysis of C4 Fine-tuning vs Base Distillation
|
2 |
+
|
3 |
+
## π Executive Summary
|
4 |
+
|
5 |
+
**Key Finding**: C4 fine-tuning significantly degraded performance across almost all metrics and programming languages compared to simple Model2Vec distillation.
|
6 |
+
|
7 |
+
**Recommendation**: Use simple Model2Vec distillation without additional training for optimal code embedding performance.
|
8 |
+
|
9 |
+
---
|
10 |
+
|
11 |
+
## π Overall Performance Degradation
|
12 |
+
|
13 |
+
The comparison between base distilled models and C4-fine-tuned models reveals substantial performance regression:
|
14 |
+
|
15 |
+
| Metric | Base Model | Fine-tuned Model | Performance Drop |
|
16 |
+
|--------|------------|------------------|------------------|
|
17 |
+
| **NDCG@10** | 0.7387 | 0.6147 | **-16.8%** |
|
18 |
+
| **MRR** | 0.7010 | 0.5720 | **-18.4%** |
|
19 |
+
| **Recall@5** | 0.8017 | 0.6950 | **-13.3%** |
|
20 |
+
| **Recall@1** | 0.6169 | 0.4650 | **-24.6%** |
|
21 |
+
|
22 |
+
**Impact**: Double-digit performance drops across all major retrieval metrics, with Recall@1 suffering the most severe degradation at nearly 25%.
|
23 |
+
|
24 |
+
---
|
25 |
+
|
26 |
+
## π Language-Specific Impact Analysis
|
27 |
+
|
28 |
+
The performance degradation varied significantly across programming languages, revealing interesting patterns about domain sensitivity:
|
29 |
+
|
30 |
+
### π¨ **Severely Affected Languages**
|
31 |
+
|
32 |
+
#### **Java** (Catastrophic degradation):
|
33 |
+
- **NDCG@10**: 0.7027 β 0.2820 (**-59.9%**)
|
34 |
+
- **MRR**: 0.6553 β 0.2419 (**-63.1%**)
|
35 |
+
- **Mean Rank**: 7.24 β 20.38 (almost 3x worse ranking)
|
36 |
+
- **Analysis**: Java suffered the most severe degradation, suggesting its documentation patterns are most incompatible with C4's web text distribution.
|
37 |
+
|
38 |
+
#### **PHP** (Major degradation):
|
39 |
+
- **NDCG@10**: 0.7055 β 0.4453 (**-36.9%**)
|
40 |
+
- **MRR**: 0.6631 β 0.3981 (**-40.0%**)
|
41 |
+
- **Analysis**: PHP's unique syntax and documentation style may have been particularly disrupted by general web text training.
|
42 |
+
|
43 |
+
### π **Moderately Affected Languages**
|
44 |
+
|
45 |
+
#### **Python** (Best preserved):
|
46 |
+
- **NDCG@10**: 0.9674 β 0.9219 (**-4.7%**)
|
47 |
+
- **MRR**: 0.9572 β 0.8964 (**-6.3%**)
|
48 |
+
- **Analysis**: Python showed the smallest degradation, likely due to its prevalence in web tutorials and documentation that might overlap with C4 content.
|
49 |
+
|
50 |
+
#### **Ruby** (Minor degradation):
|
51 |
+
- **NDCG@10**: 0.7287 β 0.7178 (**-1.5%**)
|
52 |
+
- **MRR**: 0.6869 β 0.6776 (**-1.4%**)
|
53 |
+
|
54 |
+
#### **Go** (Minor degradation):
|
55 |
+
- **NDCG@10**: 0.7529 β 0.7250 (**-3.7%**)
|
56 |
+
- **MRR**: 0.7059 β 0.6699 (**-5.1%**)
|
57 |
+
|
58 |
+
### β
**Single Improvement**
|
59 |
+
|
60 |
+
#### **JavaScript** (Slight improvement):
|
61 |
+
- **NDCG@10**: 0.5752 β 0.5959 (**+3.6%**)
|
62 |
+
- **MRR**: 0.5378 β 0.5481 (**+1.9%**)
|
63 |
+
- **Analysis**: JavaScript was the only language to show improvement, possibly due to extensive JavaScript content in web pages that align with C4's distribution.
|
64 |
+
|
65 |
+
---
|
66 |
+
|
67 |
+
## π Model Characteristics Comparison
|
68 |
+
|
69 |
+
| Aspect | Base Model | Fine-tuned Model | Change | Impact |
|
70 |
+
|--------|------------|------------------|--------|---------|
|
71 |
+
| **Parameters** | 7.56M | 9.38M | +24% larger | Increased complexity |
|
72 |
+
| **Disk Size** | 15.07MB | 36.94MB | +145% larger | Storage overhead |
|
73 |
+
| **Performance** | Superior | Inferior | Significantly worse | Counterproductive |
|
74 |
+
| **Efficiency** | High | Low | Worse per parameter | Resource waste |
|
75 |
+
|
76 |
+
**Key Insight**: The fine-tuned model is larger, more complex, and performs worseβa clear example of the "bigger is not always better" principle.
|
77 |
+
|
78 |
+
---
|
79 |
+
|
80 |
+
## π§ Root Cause Analysis
|
81 |
+
|
82 |
+
### 1. **π Domain Mismatch**
|
83 |
+
- **Problem**: C4 contains general web text (articles, forums, websites, news)
|
84 |
+
- **Impact**: Code documentation has fundamentally different linguistic patterns, vocabulary, and structure
|
85 |
+
- **Result**: Training on web text actively degraded code-specific knowledge
|
86 |
+
|
87 |
+
### 2. **π§ Catastrophic Forgetting**
|
88 |
+
- **Problem**: The model "forgot" code-specific embeddings during C4 training
|
89 |
+
- **Evidence**: Java and PHP were hit hardest (59.9% and 36.9% NDCG@10 drops respectively)
|
90 |
+
- **Mechanism**: New training overwrote previously learned code-specific representations
|
91 |
+
|
92 |
+
### 3. **π Distribution Shift**
|
93 |
+
- **Problem**: C4's token distribution is vastly different from code comments and documentation
|
94 |
+
- **Impact**: Model learned patterns that are irrelevant or harmful for code retrieval
|
95 |
+
- **Evidence**: Uniform degradation across most languages suggests systematic distribution mismatch
|
96 |
+
|
97 |
+
### 4. **βοΈ Training Methodology Issues**
|
98 |
+
- **Problem**: Tokenlearn training on C4 introduced noise rather than signal
|
99 |
+
- **Analysis**: The POTION approach works well for general text but fails for specialized domains
|
100 |
+
- **Conclusion**: Domain-agnostic training methods can be counterproductive
|
101 |
+
|
102 |
+
---
|
103 |
+
|
104 |
+
## π Performance vs Complexity Analysis
|
105 |
+
|
106 |
+
```
|
107 |
+
Performance Efficiency = NDCG@10 / Model_Size_MB
|
108 |
+
|
109 |
+
Base Model: 0.7387 / 15.07 = 0.049 (High efficiency)
|
110 |
+
Fine-tuned Model: 0.6147 / 36.94 = 0.017 (Low efficiency)
|
111 |
+
|
112 |
+
Efficiency Loss: 65.3%
|
113 |
+
```
|
114 |
+
|
115 |
+
The fine-tuned model is not only worse performing but also dramatically less efficient, representing a significant regression in both absolute and relative terms.
|
116 |
+
|
117 |
+
---
|
118 |
+
|
119 |
+
## π― Key Research Insights
|
120 |
+
|
121 |
+
### 1. **Domain Specificity Matters**
|
122 |
+
Code embeddings require domain-specific training data. General web text (C4) actively harms code retrieval performance.
|
123 |
+
|
124 |
+
### 2. **Language-Dependent Vulnerability**
|
125 |
+
Programming languages show different sensitivity to domain shift:
|
126 |
+
- **High vulnerability**: Java, PHP (enterprise/web languages)
|
127 |
+
- **Medium vulnerability**: Go, Ruby
|
128 |
+
- **Low vulnerability**: Python (ubiquitous in tutorials)
|
129 |
+
- **Potential benefit**: JavaScript (web-native language)
|
130 |
+
|
131 |
+
### 3. **Simple Distillation Superiority**
|
132 |
+
Model2Vec's simple distillation approach outperforms complex fine-tuning when training data is misaligned with the target domain.
|
133 |
+
|
134 |
+
### 4. **Training Data Quality > Quantity**
|
135 |
+
Using massive but irrelevant data (C4) is worse than using no additional training at all.
|
136 |
+
|
137 |
+
---
|
138 |
+
|
139 |
+
## π Actionable Recommendations
|
140 |
+
|
141 |
+
### β **What NOT to Do**
|
142 |
+
1. **Don't use C4 for code models**: General web text degrades code-specific performance
|
143 |
+
2. **Don't assume more training is better**: Additional training can be counterproductive
|
144 |
+
3. **Don't ignore domain alignment**: Training data must match target application domain
|
145 |
+
4. **Don't prioritize model size**: Larger models can perform worse if poorly trained
|
146 |
+
|
147 |
+
### β
**What TO Do**
|
148 |
+
1. **Stick to base distillation**: Simple Model2Vec distillation gives optimal results for code tasks
|
149 |
+
2. **Use code-specific datasets only**: If fine-tuning is needed, use CodeSearchNet or similar datasets
|
150 |
+
3. **Validate domain alignment**: Ensure training data distribution matches target use case
|
151 |
+
4. **Measure efficiency**: Consider performance per parameter, not just absolute performance
|
152 |
+
5. **Test incrementally**: Validate that each training step improves rather than degrades performance
|
153 |
+
|
154 |
+
### π¬ **Future Research Directions**
|
155 |
+
1. **Code-specific fine-tuning**: Investigate tokenlearn training with CodeSearchNet instead of C4
|
156 |
+
2. **Selective fine-tuning**: Apply additional training only to languages that show potential benefit (JavaScript)
|
157 |
+
3. **Hybrid approaches**: Combine base distillation with minimal, targeted code-specific training
|
158 |
+
4. **Domain adaptation techniques**: Develop methods to prevent catastrophic forgetting during domain transfer
|
159 |
+
|
160 |
+
---
|
161 |
+
|
162 |
+
## π Statistical Significance
|
163 |
+
|
164 |
+
All performance drops are substantial and consistent across metrics:
|
165 |
+
- **Minimum degradation**: 1.4% (Ruby MRR)
|
166 |
+
- **Maximum degradation**: 63.1% (Java MRR)
|
167 |
+
- **Median degradation**: ~15% across all metrics
|
168 |
+
- **Only improvement**: JavaScript (+3.6% NDCG@10)
|
169 |
+
|
170 |
+
**Conclusion**: The degradation is not due to random variation but represents a systematic failure of the C4 fine-tuning approach.
|
171 |
+
|
172 |
+
---
|
173 |
+
|
174 |
+
## π Lessons Learned
|
175 |
+
|
176 |
+
1. **Domain expertise beats scale**: Code-specific knowledge is more valuable than training on massive general datasets
|
177 |
+
2. **Validate training approaches**: Always compare against simpler baselines before deploying complex training pipelines
|
178 |
+
3. **Language-specific patterns matter**: Different programming languages have varying sensitivity to domain shift
|
179 |
+
4. **Efficiency is crucial**: Model performance per parameter is often more important than absolute performance
|
180 |
+
5. **Simple can be superior**: Sometimes the simplest approach (basic distillation) outperforms sophisticated alternatives
|
181 |
+
|
182 |
+
---
|
183 |
+
|
184 |
+
**Documentation Date**: December 2024
|
185 |
+
**Model Comparison**: `sentence-transformers/all-mpnet-base-v2` teacher β Model2Vec distillation vs Model2Vec + C4 tokenlearn fine-tuning
|
186 |
+
**Evaluation Dataset**: CodeSearchNet across 6 programming languages
|
187 |
+
**Key Finding**: Simple distillation outperforms complex fine-tuning by 16.8% NDCG@10 on average
|
README.md
CHANGED
@@ -71,6 +71,9 @@ pipeline_tag: feature-extraction
|
|
71 |
>[!Important]
|
72 |
>Check out the comprehensive [REPORT.md](REPORT.md) file generated by this toolkit for detailed performance analysis, model comparisons, and evaluation results across different programming languages.
|
73 |
|
|
|
|
|
|
|
74 |
The **distiller** package provides a complete pipeline for:
|
75 |
|
76 |
1. **Distilling code-specialized embeddings** from large sentence transformer models using Model2Vec
|
|
|
71 |
>[!Important]
|
72 |
>Check out the comprehensive [REPORT.md](REPORT.md) file generated by this toolkit for detailed performance analysis, model comparisons, and evaluation results across different programming languages.
|
73 |
|
74 |
+
>[!Warning]
|
75 |
+
>**Research Finding**: See [NOTES.md](NOTES.md) for critical analysis showing that C4 fine-tuning significantly degraded performance (-16.8% NDCG@10) compared to simple Model2Vec distillation. **Recommendation**: Use basic distillation without additional training for optimal code embedding performance.
|
76 |
+
|
77 |
The **distiller** package provides a complete pipeline for:
|
78 |
|
79 |
1. **Distilling code-specialized embeddings** from large sentence transformer models using Model2Vec
|
REPORT.md
CHANGED
@@ -29,7 +29,7 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
|
|
29 |
| code_model2vec_jina_embeddings_v2_base_code | [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code) | 0.7381 | 0.6996 | 0.8130 | π₯ 3rd |
|
30 |
| code_model2vec_paraphrase_MiniLM_L6_v2 | [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) | 0.7013 | 0.6638 | 0.7665 | #4 |
|
31 |
| code_model2vec_Reason_ModernColBERT | [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) | 0.6598 | 0.6228 | 0.7260 | #5 |
|
32 |
-
| code_model2vec_all_mpnet_base_v2_fine_tuned | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) | 0.
|
33 |
| code_model2vec_bge_m3 | [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | 0.4863 | 0.4439 | 0.5514 | #7 |
|
34 |
| code_model2vec_jina_embeddings_v3 | [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) | 0.4755 | 0.4416 | 0.5456 | #8 |
|
35 |
| code_model2vec_nomic_embed_text_v2_moe | [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) | 0.4532 | 0.4275 | 0.5094 | #9 |
|
@@ -51,7 +51,7 @@ Our distilled models exhibit consistent architectural characteristics across dif
|
|
51 |
| jina_embeddings_v2_base_code | 61,053 | 15.6M | 256 | 29.8MB |
|
52 |
| paraphrase_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
|
53 |
| Reason_ModernColBERT | 50,254 | 12.9M | 256 | 24.5MB |
|
54 |
-
| all_mpnet_base_v2_fine_tuned |
|
55 |
| bge_m3 | 249,999 | 64.0M | 256 | 122.1MB |
|
56 |
| jina_embeddings_v3 | 249,999 | 64.0M | 256 | 122.1MB |
|
57 |
| nomic_embed_text_v2_moe | 249,999 | 64.0M | 256 | 122.1MB |
|
@@ -69,9 +69,9 @@ Our distilled models exhibit consistent architectural characteristics across dif
|
|
69 |
#### Key Insights from Model Specifications:
|
70 |
|
71 |
|
72 |
-
- **Vocabulary Consistency**: All models use vocabulary sizes ranging from 29,525 to 249,999 tokens (avg: 101,
|
73 |
-
- **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg:
|
74 |
-
- **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg: 50.
|
75 |
- **Embedding Dimensions**: Consistent 256 dimensions across all models (optimized for efficiency)
|
76 |
|
77 |
|
@@ -81,13 +81,85 @@ Our distilled models exhibit consistent architectural characteristics across dif
|
|
81 |
- **Best Teacher Model**: code_model2vec_all_mpnet_base_v2 (NDCG@10: 0.7387)
|
82 |
- **Least Effective Teacher**: code_model2vec_codebert_base (NDCG@10: 0.2779)
|
83 |
- **Performance Range**: 62.4% difference between best and worst
|
84 |
-
- **Average Performance**: 0.
|
85 |
|
86 |
|
87 |
## π― Language Performance Radar Charts
|
88 |
|
89 |
### Best Model vs Peer Models Comparison
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
| Rank | Model | Type | NDCG@10 | MRR | Recall@5 |
|
92 |
|------|-------|------|---------|-----|----------|
|
93 |
| 1 | Alibaba-NLP/gte-Qwen2-1.5B-instruct | General | 0.9729 | 0.9676 | 0.9825 |
|
@@ -109,9 +181,9 @@ Our distilled models exhibit consistent architectural characteristics across dif
|
|
109 |
| 17 | code_model2vec_jina_embeddings_v2_base_code | **π₯ Simplified Distillation** | 0.7381 | 0.6996 | 0.8130 |
|
110 |
| 18 | code_model2vec_paraphrase_MiniLM_L6_v2 | **π₯ Simplified Distillation** | 0.7013 | 0.6638 | 0.7665 |
|
111 |
| 19 | code_model2vec_Reason_ModernColBERT | **π₯ Simplified Distillation** | 0.6598 | 0.6228 | 0.7260 |
|
112 |
-
| 20 |
|
113 |
-
| 21 |
|
114 |
-
| 22 |
|
115 |
| 23 | Salesforce/codet5-base | Code-Specific | 0.4872 | 0.4500 | 0.5742 |
|
116 |
| 24 | code_model2vec_bge_m3 | **π₯ Simplified Distillation** | 0.4863 | 0.4439 | 0.5514 |
|
117 |
| 25 | code_model2vec_jina_embeddings_v3 | **π₯ Simplified Distillation** | 0.4755 | 0.4416 | 0.5456 |
|
@@ -171,12 +243,12 @@ Our distilled models exhibit consistent architectural characteristics across dif
|
|
171 |
|
172 |
| Language | Best Model Performance | Average Performance | Language Difficulty |
|
173 |
|----------|------------------------|--------------------|--------------------|
|
174 |
-
| Go | 0.9780 | 0.
|
175 |
-
| Java | 0.9921 | 0.
|
176 |
-
| Javascript | 0.9550 | 0.
|
177 |
-
| Php | 1.0000 | 0.
|
178 |
-
| Python | 1.0000 | 0.
|
179 |
-
| Ruby | 0.9493 | 0.
|
180 |
|
181 |
|
182 |
## π― Conclusions and Recommendations
|
@@ -230,5 +302,5 @@ Based on the evaluation results across all simplified distillation models:
|
|
230 |
|
231 |
---
|
232 |
|
233 |
-
*Report generated on 2025-
|
234 |
*For questions about methodology or results, please refer to the CodeSearchNet documentation.*
|
|
|
29 |
| code_model2vec_jina_embeddings_v2_base_code | [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code) | 0.7381 | 0.6996 | 0.8130 | π₯ 3rd |
|
30 |
| code_model2vec_paraphrase_MiniLM_L6_v2 | [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) | 0.7013 | 0.6638 | 0.7665 | #4 |
|
31 |
| code_model2vec_Reason_ModernColBERT | [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) | 0.6598 | 0.6228 | 0.7260 | #5 |
|
32 |
+
| code_model2vec_all_mpnet_base_v2_fine_tuned | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) | 0.6147 | 0.5720 | 0.6950 | #6 |
|
33 |
| code_model2vec_bge_m3 | [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | 0.4863 | 0.4439 | 0.5514 | #7 |
|
34 |
| code_model2vec_jina_embeddings_v3 | [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) | 0.4755 | 0.4416 | 0.5456 | #8 |
|
35 |
| code_model2vec_nomic_embed_text_v2_moe | [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) | 0.4532 | 0.4275 | 0.5094 | #9 |
|
|
|
51 |
| jina_embeddings_v2_base_code | 61,053 | 15.6M | 256 | 29.8MB |
|
52 |
| paraphrase_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
|
53 |
| Reason_ModernColBERT | 50,254 | 12.9M | 256 | 24.5MB |
|
54 |
+
| all_mpnet_base_v2_fine_tuned | 36,624 | 9.4M | 256 | 35.8MB |
|
55 |
| bge_m3 | 249,999 | 64.0M | 256 | 122.1MB |
|
56 |
| jina_embeddings_v3 | 249,999 | 64.0M | 256 | 122.1MB |
|
57 |
| nomic_embed_text_v2_moe | 249,999 | 64.0M | 256 | 122.1MB |
|
|
|
69 |
#### Key Insights from Model Specifications:
|
70 |
|
71 |
|
72 |
+
- **Vocabulary Consistency**: All models use vocabulary sizes ranging from 29,525 to 249,999 tokens (avg: 101,594)
|
73 |
+
- **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg: 26.0M)
|
74 |
+
- **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg: 50.9MB)
|
75 |
- **Embedding Dimensions**: Consistent 256 dimensions across all models (optimized for efficiency)
|
76 |
|
77 |
|
|
|
81 |
- **Best Teacher Model**: code_model2vec_all_mpnet_base_v2 (NDCG@10: 0.7387)
|
82 |
- **Least Effective Teacher**: code_model2vec_codebert_base (NDCG@10: 0.2779)
|
83 |
- **Performance Range**: 62.4% difference between best and worst
|
84 |
+
- **Average Performance**: 0.5248 NDCG@10
|
85 |
|
86 |
|
87 |
## π― Language Performance Radar Charts
|
88 |
|
89 |
### Best Model vs Peer Models Comparison
|
90 |
|
91 |
+

|
92 |
+
|
93 |
+
*Comparative view showing how the best simplified distillation model performs against top peer models across programming languages.*
|
94 |
+
|
95 |
+
### Individual Model Performance by Language
|
96 |
+
|
97 |
+
#### code_model2vec_all_mpnet_base_v2 (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.7387
|
98 |
+
|
99 |
+

|
100 |
+
|
101 |
+
#### code_model2vec_all_MiniLM_L6_v2 (Teacher: [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)) - NDCG@10: 0.7385
|
102 |
+
|
103 |
+

|
104 |
+
|
105 |
+
#### code_model2vec_jina_embeddings_v2_base_code (Teacher: [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code)) - NDCG@10: 0.7381
|
106 |
+
|
107 |
+

|
108 |
+
|
109 |
+
#### code_model2vec_paraphrase_MiniLM_L6_v2 (Teacher: [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)) - NDCG@10: 0.7013
|
110 |
+
|
111 |
+

|
112 |
+
|
113 |
+
#### code_model2vec_Reason_ModernColBERT (Teacher: [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT)) - NDCG@10: 0.6598
|
114 |
+
|
115 |
+

|
116 |
+
|
117 |
+
#### code_model2vec_all_mpnet_base_v2_fine_tuned (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.6147
|
118 |
+
|
119 |
+

|
120 |
+
|
121 |
+
#### code_model2vec_bge_m3 (Teacher: [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3)) - NDCG@10: 0.4863
|
122 |
+
|
123 |
+

|
124 |
+
|
125 |
+
#### code_model2vec_jina_embeddings_v3 (Teacher: [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3)) - NDCG@10: 0.4755
|
126 |
+
|
127 |
+

|
128 |
+
|
129 |
+
#### code_model2vec_nomic_embed_text_v2_moe (Teacher: [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe)) - NDCG@10: 0.4532
|
130 |
+
|
131 |
+

|
132 |
+
|
133 |
+
#### code_model2vec_gte_Qwen2_1.5B_instruct (Teacher: [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)) - NDCG@10: 0.4238
|
134 |
+
|
135 |
+

|
136 |
+
|
137 |
+
#### code_model2vec_Qodo_Embed_1_1.5B (Teacher: [Qodo/Qodo-Embed-1-1.5B](https://huggingface.co/Qodo/Qodo-Embed-1-1.5B)) - NDCG@10: 0.4101
|
138 |
+
|
139 |
+

|
140 |
+
|
141 |
+
#### code_model2vec_graphcodebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.3420
|
142 |
+
|
143 |
+

|
144 |
+
|
145 |
+
#### code_model2vec_Linq_Embed_Mistral (Teacher: [Linq-AI-Research/Linq-Embed-Mistral](https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral)) - NDCG@10: 0.2868
|
146 |
+
|
147 |
+

|
148 |
+
|
149 |
+
#### code_model2vec_codebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.2779
|
150 |
+
|
151 |
+

|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
+
## π Peer Model Comparison
|
156 |
+
|
157 |
+

|
158 |
+
|
159 |
+
*Comparison with established code-specialized embedding models using actual evaluation results.*
|
160 |
+
|
161 |
+
### Complete Model Ranking
|
162 |
+
|
163 |
| Rank | Model | Type | NDCG@10 | MRR | Recall@5 |
|
164 |
|------|-------|------|---------|-----|----------|
|
165 |
| 1 | Alibaba-NLP/gte-Qwen2-1.5B-instruct | General | 0.9729 | 0.9676 | 0.9825 |
|
|
|
181 |
| 17 | code_model2vec_jina_embeddings_v2_base_code | **π₯ Simplified Distillation** | 0.7381 | 0.6996 | 0.8130 |
|
182 |
| 18 | code_model2vec_paraphrase_MiniLM_L6_v2 | **π₯ Simplified Distillation** | 0.7013 | 0.6638 | 0.7665 |
|
183 |
| 19 | code_model2vec_Reason_ModernColBERT | **π₯ Simplified Distillation** | 0.6598 | 0.6228 | 0.7260 |
|
184 |
+
| 20 | code_model2vec_all_mpnet_base_v2_fine_tuned | **π Fine-tuned Distillation** | 0.6147 | 0.5720 | 0.6950 |
|
185 |
+
| 21 | potion-multilingual-128M | Model2Vec | 0.6124 | 0.5683 | 0.7017 |
|
186 |
+
| 22 | huggingface/CodeBERTa-small-v1 | Code-Specific | 0.5903 | 0.5350 | 0.6779 |
|
187 |
| 23 | Salesforce/codet5-base | Code-Specific | 0.4872 | 0.4500 | 0.5742 |
|
188 |
| 24 | code_model2vec_bge_m3 | **π₯ Simplified Distillation** | 0.4863 | 0.4439 | 0.5514 |
|
189 |
| 25 | code_model2vec_jina_embeddings_v3 | **π₯ Simplified Distillation** | 0.4755 | 0.4416 | 0.5456 |
|
|
|
243 |
|
244 |
| Language | Best Model Performance | Average Performance | Language Difficulty |
|
245 |
|----------|------------------------|--------------------|--------------------|
|
246 |
+
| Go | 0.9780 | 0.6960 | Easy |
|
247 |
+
| Java | 0.9921 | 0.6553 | Easy |
|
248 |
+
| Javascript | 0.9550 | 0.5850 | Easy |
|
249 |
+
| Php | 1.0000 | 0.6321 | Easy |
|
250 |
+
| Python | 1.0000 | 0.8623 | Easy |
|
251 |
+
| Ruby | 0.9493 | 0.6397 | Easy |
|
252 |
|
253 |
|
254 |
## π― Conclusions and Recommendations
|
|
|
302 |
|
303 |
---
|
304 |
|
305 |
+
*Report generated on 2025-06-01 08:04:06 using automated analysis pipeline.*
|
306 |
*For questions about methodology or results, please refer to the CodeSearchNet documentation.*
|
analysis_charts/batch_size_scaling.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
analysis_charts/benchmark_performance.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
analysis_charts/efficiency_analysis.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
analysis_charts/language_heatmap.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
analysis_charts/memory_scaling.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
analysis_charts/model_comparison.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
analysis_charts/model_specifications.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
analysis_charts/peer_comparison.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
analysis_charts/radar_code_model2vec_all_mpnet_base_v2_fine_tuned.png
CHANGED
![]() |
Git LFS Details
|
![]() |
Git LFS Details
|
src/distiller/__main__.py
CHANGED
@@ -23,9 +23,6 @@ def distill(
|
|
23 |
clear_checkpoints: Annotated[
|
24 |
bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
|
25 |
] = False,
|
26 |
-
skip_ptr: Annotated[
|
27 |
-
bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
|
28 |
-
] = False,
|
29 |
use_optimized_dataset: Annotated[
|
30 |
bool,
|
31 |
typer.Option(
|
@@ -48,7 +45,6 @@ def distill(
|
|
48 |
pca_dims,
|
49 |
clear_cache,
|
50 |
clear_checkpoints,
|
51 |
-
skip_ptr,
|
52 |
use_optimized_dataset,
|
53 |
dataset_path,
|
54 |
)
|
|
|
23 |
clear_checkpoints: Annotated[
|
24 |
bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
|
25 |
] = False,
|
|
|
|
|
|
|
26 |
use_optimized_dataset: Annotated[
|
27 |
bool,
|
28 |
typer.Option(
|
|
|
45 |
pca_dims,
|
46 |
clear_cache,
|
47 |
clear_checkpoints,
|
|
|
48 |
use_optimized_dataset,
|
49 |
dataset_path,
|
50 |
)
|
src/distiller/config.py
CHANGED
@@ -210,16 +210,14 @@ class DistillationConfig(BaseModel):
|
|
210 |
apply_zipf: bool = True
|
211 |
|
212 |
# Tokenlearn-specific parameters (POTION approach)
|
213 |
-
tokenlearn_dataset: str = "
|
214 |
-
tokenlearn_dataset_name: str = "
|
215 |
-
tokenlearn_text_key: str =
|
216 |
-
"combined_text" # Text field to use from the dataset ('combined_text' for doc-code pairs)
|
217 |
-
)
|
218 |
tokenlearn_timeout_featurize: int = 21600 # 6 hour timeout for featurization (dataset needs ~5 hours)
|
219 |
tokenlearn_timeout_train: int = 7200 # 2 hour timeout for training
|
220 |
|
221 |
-
#
|
222 |
-
|
223 |
|
224 |
# Dataset configuration
|
225 |
use_optimized_dataset: bool = True # Use the pre-created optimized dataset from dataset.py
|
|
|
210 |
apply_zipf: bool = True
|
211 |
|
212 |
# Tokenlearn-specific parameters (POTION approach)
|
213 |
+
tokenlearn_dataset: str = "allenai/c4" # Dataset for tokenlearn featurization (following POTION paper)
|
214 |
+
tokenlearn_dataset_name: str = "en" # Use 'en' configuration for English text
|
215 |
+
tokenlearn_text_key: str = "text" # Text field to use from the dataset
|
|
|
|
|
216 |
tokenlearn_timeout_featurize: int = 21600 # 6 hour timeout for featurization (dataset needs ~5 hours)
|
217 |
tokenlearn_timeout_train: int = 7200 # 2 hour timeout for training
|
218 |
|
219 |
+
# Dataset sampling configuration
|
220 |
+
tokenlearn_max_samples: int = 50000 # Maximum samples to use for tokenlearn training
|
221 |
|
222 |
# Dataset configuration
|
223 |
use_optimized_dataset: bool = True # Use the pre-created optimized dataset from dataset.py
|
src/distiller/distill.py
CHANGED
@@ -28,7 +28,6 @@ import time
|
|
28 |
from pathlib import Path
|
29 |
from typing import Annotated, Any
|
30 |
|
31 |
-
import numpy as np
|
32 |
import torch
|
33 |
import typer
|
34 |
from beam import function
|
@@ -410,7 +409,7 @@ def simple_distillation(
|
|
410 |
|
411 |
|
412 |
def load_optimized_dataset(
|
413 |
-
max_samples: int =
|
414 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
415 |
dataset_path: str | None = None,
|
416 |
) -> list[str]:
|
@@ -424,6 +423,10 @@ def load_optimized_dataset(
|
|
424 |
|
425 |
dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR
|
426 |
|
|
|
|
|
|
|
|
|
427 |
logger.info(f"π― Loading optimized dataset from {dataset_dir}")
|
428 |
logger.info(f"π Target samples: {max_samples}")
|
429 |
|
@@ -462,12 +465,16 @@ def load_optimized_dataset(
|
|
462 |
|
463 |
|
464 |
def load_codesearchnet_dataset(
|
465 |
-
max_samples: int =
|
466 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
467 |
) -> list[str]:
|
468 |
"""Load and format the CodeSearchNet dataset for token frequency computation."""
|
469 |
from datasets import load_dataset
|
470 |
|
|
|
|
|
|
|
|
|
471 |
logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
|
472 |
logger.info(f"Limiting to {max_samples} samples for training efficiency")
|
473 |
logger.info(f"Languages: {', '.join(languages_config.all)}")
|
@@ -732,192 +739,10 @@ def generate_teacher_embeddings(
|
|
732 |
return teacher_embeddings
|
733 |
|
734 |
|
735 |
-
def compute_token_frequencies_for_sif(
|
736 |
-
teacher_model: SentenceTransformer,
|
737 |
-
features_dir: Path,
|
738 |
-
) -> None:
|
739 |
-
"""
|
740 |
-
Compute token frequencies from the training corpus for SIF weighting.
|
741 |
-
|
742 |
-
This follows the POTION approach for post-training re-regularization.
|
743 |
-
"""
|
744 |
-
import json
|
745 |
-
from collections import Counter
|
746 |
-
|
747 |
-
logger.info("π Computing token frequencies for SIF weighting...")
|
748 |
-
|
749 |
-
try:
|
750 |
-
# Load dataset to compute frequencies (limited sample for efficiency)
|
751 |
-
if distillation_config.use_optimized_dataset:
|
752 |
-
# Use the custom optimized dataset
|
753 |
-
from .dataset import load_optimized_dataset as load_custom_dataset
|
754 |
-
|
755 |
-
custom_dataset_dir = (
|
756 |
-
Path(distillation_config.custom_dataset_path)
|
757 |
-
if distillation_config.custom_dataset_path
|
758 |
-
else Path("code_model2vec/dataset")
|
759 |
-
)
|
760 |
-
|
761 |
-
if custom_dataset_dir.exists() and (custom_dataset_dir / "train.parquet").exists():
|
762 |
-
train_df = load_custom_dataset(output_dir=custom_dataset_dir, split="train")
|
763 |
-
# Sample a subset for frequency computation
|
764 |
-
sample_size = min(10000, len(train_df))
|
765 |
-
train_df_sample = train_df.sample(n=sample_size, random_state=42)
|
766 |
-
dataset_texts = train_df_sample["text"].tolist()
|
767 |
-
logger.info(f"π Using {len(dataset_texts)} samples from custom optimized dataset")
|
768 |
-
else:
|
769 |
-
# Fallback to original dataset loading
|
770 |
-
dataset_texts = load_codesearchnet_dataset(max_samples=10000)
|
771 |
-
logger.info(
|
772 |
-
f"π Custom dataset not found, using original CodeSearchNet with {len(dataset_texts)} texts"
|
773 |
-
)
|
774 |
-
else:
|
775 |
-
dataset_texts = load_codesearchnet_dataset(max_samples=10000)
|
776 |
-
logger.info(f"π Using original CodeSearchNet with {len(dataset_texts)} texts")
|
777 |
-
|
778 |
-
logger.info(f"π Computing frequencies on {len(dataset_texts)} texts...")
|
779 |
-
|
780 |
-
# Tokenize all texts and count token frequencies
|
781 |
-
tokenizer = teacher_model.tokenizer
|
782 |
-
token_counts: Counter[int] = Counter()
|
783 |
-
|
784 |
-
# Process in batches to avoid memory issues
|
785 |
-
batch_size = 100
|
786 |
-
for i in range(0, len(dataset_texts), batch_size):
|
787 |
-
batch_texts = dataset_texts[i : i + batch_size]
|
788 |
-
|
789 |
-
for text in batch_texts:
|
790 |
-
# Tokenize the text
|
791 |
-
tokens = tokenizer.encode(text, add_special_tokens=False)
|
792 |
-
token_counts.update(tokens)
|
793 |
-
|
794 |
-
if i % (batch_size * 10) == 0:
|
795 |
-
logger.info(f" Processed {i + len(batch_texts)}/{len(dataset_texts)} texts...")
|
796 |
-
|
797 |
-
# Convert to frequencies (token_id -> count)
|
798 |
-
token_frequencies = dict(token_counts)
|
799 |
-
|
800 |
-
# Save token frequencies to features directory for post-training regularization
|
801 |
-
freq_file = features_dir / "token_frequencies.json"
|
802 |
-
with freq_file.open("w") as f:
|
803 |
-
json.dump(token_frequencies, f, indent=2)
|
804 |
-
|
805 |
-
logger.info(f"β
Token frequencies saved to {freq_file}")
|
806 |
-
logger.info(f"π Total unique tokens: {len(token_frequencies)}")
|
807 |
-
logger.info(f"π Total token occurrences: {sum(token_frequencies.values())}")
|
808 |
-
|
809 |
-
except Exception as e:
|
810 |
-
logger.warning(f"β οΈ Failed to compute token frequencies: {e}")
|
811 |
-
logger.warning("β οΈ Post-training re-regularization will use default Zipf weighting")
|
812 |
-
|
813 |
-
|
814 |
-
def apply_post_training_regularization(
|
815 |
-
model: Any,
|
816 |
-
features_dir: Path,
|
817 |
-
pca_dims: int = 256,
|
818 |
-
) -> Any:
|
819 |
-
"""
|
820 |
-
Apply post-training re-regularization following the POTION approach.
|
821 |
-
|
822 |
-
This includes:
|
823 |
-
1. Token frequency weighting using corpus frequencies
|
824 |
-
2. PCA application
|
825 |
-
3. SIF weighting using formula: w = 1e-3 / (1e-3 + proba)
|
826 |
-
"""
|
827 |
-
import json
|
828 |
-
|
829 |
-
from sklearn.decomposition import PCA
|
830 |
-
|
831 |
-
logger.info("π§ Starting post-training re-regularization (POTION Step 4)")
|
832 |
-
|
833 |
-
# Step 4a: Load token frequencies from the training corpus
|
834 |
-
logger.info("π Computing token frequencies from training corpus...")
|
835 |
-
|
836 |
-
# Try to load token frequencies from features directory
|
837 |
-
freq_file = features_dir / "token_frequencies.json"
|
838 |
-
|
839 |
-
if freq_file.exists():
|
840 |
-
with freq_file.open("r") as f:
|
841 |
-
token_frequencies = json.load(f)
|
842 |
-
logger.info(f"β
Loaded token frequencies from {freq_file}")
|
843 |
-
else:
|
844 |
-
logger.warning("β οΈ Token frequencies not found - using default Zipf weighting")
|
845 |
-
# Fallback to basic frequency estimation based on rank
|
846 |
-
vocab_size = model.embedding.shape[0]
|
847 |
-
token_frequencies = {str(i): 1.0 / (i + 1) for i in range(vocab_size)}
|
848 |
-
|
849 |
-
# Step 4b: Apply PCA to the embeddings
|
850 |
-
logger.info(f"π Applying PCA with {pca_dims} dimensions...")
|
851 |
-
|
852 |
-
# Get current embeddings
|
853 |
-
# Handle both torch tensors and numpy arrays
|
854 |
-
if hasattr(model.embedding, "cpu"):
|
855 |
-
embeddings = model.embedding.cpu().numpy().astype(np.float64)
|
856 |
-
else:
|
857 |
-
embeddings = model.embedding.astype(np.float64)
|
858 |
-
original_shape = embeddings.shape
|
859 |
-
logger.info(f"Original embedding shape: {original_shape}")
|
860 |
-
|
861 |
-
# Apply PCA if dimensions don't match
|
862 |
-
if original_shape[1] != pca_dims:
|
863 |
-
pca = PCA(n_components=pca_dims, random_state=42)
|
864 |
-
embeddings_pca = pca.fit_transform(embeddings)
|
865 |
-
logger.info(f"PCA applied: {original_shape} β {embeddings_pca.shape}")
|
866 |
-
|
867 |
-
# Explained variance ratio
|
868 |
-
explained_var = pca.explained_variance_ratio_.sum()
|
869 |
-
logger.info(f"PCA explained variance ratio: {explained_var:.4f}")
|
870 |
-
else:
|
871 |
-
embeddings_pca = embeddings
|
872 |
-
logger.info("PCA dimensions match - no PCA transformation needed")
|
873 |
-
|
874 |
-
# Step 4c: Apply SIF weighting using corpus frequencies
|
875 |
-
logger.info("βοΈ Applying SIF weighting based on token frequencies...")
|
876 |
-
|
877 |
-
# Convert token frequencies to probabilities
|
878 |
-
total_tokens = sum(token_frequencies.values())
|
879 |
-
token_probs = {token: freq / total_tokens for token, freq in token_frequencies.items()}
|
880 |
-
|
881 |
-
# Apply SIF weighting: w = 1e-3 / (1e-3 + proba)
|
882 |
-
sif_coefficient = 1e-3 # Standard SIF coefficient
|
883 |
-
|
884 |
-
for i in range(embeddings_pca.shape[0]):
|
885 |
-
token_id = str(i)
|
886 |
-
prob = token_probs[token_id] if token_id in token_probs else 1.0 / len(token_probs)
|
887 |
-
|
888 |
-
# Apply SIF weighting formula
|
889 |
-
sif_weight = sif_coefficient / (sif_coefficient + prob)
|
890 |
-
embeddings_pca[i] *= sif_weight
|
891 |
-
|
892 |
-
logger.info("β
SIF weighting applied successfully")
|
893 |
-
|
894 |
-
# Step 4d: Create new model with re-regularized embeddings
|
895 |
-
logger.info("π¦ Creating final model with re-regularized embeddings...")
|
896 |
-
|
897 |
-
# Convert back to float32 numpy array
|
898 |
-
final_embeddings = embeddings_pca.astype(np.float32)
|
899 |
-
|
900 |
-
# Create new model with updated embeddings
|
901 |
-
from distiller.model2vec.model import StaticModel
|
902 |
-
|
903 |
-
# Save tokenizer and config from original model
|
904 |
-
tokenizer = model.tokenizer
|
905 |
-
config = model.config
|
906 |
-
|
907 |
-
# Create new model with re-regularized embeddings
|
908 |
-
final_model = StaticModel(vectors=final_embeddings, tokenizer=tokenizer, config=config)
|
909 |
-
|
910 |
-
logger.info("β
Post-training re-regularization completed successfully")
|
911 |
-
logger.info(f"Final model embedding shape: {final_model.embedding.shape}")
|
912 |
-
|
913 |
-
return final_model
|
914 |
-
|
915 |
-
|
916 |
def tokenlearn_training(
|
917 |
student_model: Any,
|
918 |
teacher_model: SentenceTransformer,
|
919 |
checkpoint_manager: BeamCheckpointManager | None = None, # noqa: ARG001
|
920 |
-
skip_post_training_regularization: bool = False,
|
921 |
) -> Any:
|
922 |
"""
|
923 |
Perform tokenlearn training following the official POTION approach.
|
@@ -926,7 +751,6 @@ def tokenlearn_training(
|
|
926 |
1. Model2Vec distillation (already done - student_model)
|
927 |
2. Sentence transformer inference (create features)
|
928 |
3. Tokenlearn training
|
929 |
-
4. Post-training re-regularization (PCA + SIF weighting)
|
930 |
"""
|
931 |
from pathlib import Path
|
932 |
|
@@ -1043,10 +867,6 @@ def tokenlearn_training(
|
|
1043 |
featurization_complete_marker.touch()
|
1044 |
logger.info(f"πΎ Created featurization checkpoint: {featurization_complete_marker}")
|
1045 |
|
1046 |
-
# Generate token frequencies for post-training re-regularization
|
1047 |
-
logger.info("π Computing token frequencies for SIF weighting...")
|
1048 |
-
compute_token_frequencies_for_sif(teacher_model, features_dir)
|
1049 |
-
|
1050 |
except Exception as e:
|
1051 |
logger.exception("π₯ Tokenlearn featurization failed")
|
1052 |
logger.exception("π₯ Tokenlearn featurization is required for training - cannot proceed")
|
@@ -1191,19 +1011,9 @@ def tokenlearn_training(
|
|
1191 |
logger.info("π Loading model from tokenlearn training...")
|
1192 |
trained_model = StaticModel.from_pretrained(str(trained_model_path))
|
1193 |
|
1194 |
-
#
|
1195 |
-
|
1196 |
-
|
1197 |
-
final_model = trained_model
|
1198 |
-
logger.info("β
Tokenlearn training pipeline completed successfully (without re-regularization)")
|
1199 |
-
else:
|
1200 |
-
logger.info("π§ Applying post-training re-regularization (PCA + SIF weighting)...")
|
1201 |
-
final_model = apply_post_training_regularization(
|
1202 |
-
trained_model, features_dir, pca_dims=distillation_config.optimal_pca_dims
|
1203 |
-
)
|
1204 |
-
logger.info("β
Tokenlearn training pipeline with post-training re-regularization completed successfully")
|
1205 |
-
|
1206 |
-
return final_model
|
1207 |
|
1208 |
except ValueError as e:
|
1209 |
if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
|
@@ -1366,7 +1176,6 @@ def distill_single_teacher(
|
|
1366 |
base_model,
|
1367 |
teacher_st_model,
|
1368 |
checkpoint_mgr,
|
1369 |
-
skip_post_training_regularization=distillation_config.skip_post_training_regularization,
|
1370 |
)
|
1371 |
|
1372 |
# Save final model
|
@@ -1706,9 +1515,6 @@ def main(
|
|
1706 |
clear_checkpoints: Annotated[
|
1707 |
bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
|
1708 |
] = False,
|
1709 |
-
skip_ptr: Annotated[
|
1710 |
-
bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
|
1711 |
-
] = False,
|
1712 |
use_optimized_dataset: Annotated[
|
1713 |
bool,
|
1714 |
typer.Option(
|
@@ -1723,17 +1529,15 @@ def main(
|
|
1723 |
"""Unified distillation command with optional training."""
|
1724 |
logger.info("π Starting unified Model2Vec distillation workflow")
|
1725 |
|
1726 |
-
# Set post-training regularization flag in config
|
1727 |
-
distillation_config.skip_post_training_regularization = skip_ptr
|
1728 |
-
if skip_ptr and train:
|
1729 |
-
logger.info("βοΈ Post-training re-regularization will be skipped (PCA + SIF weighting disabled)")
|
1730 |
-
|
1731 |
# Set dataset configuration
|
1732 |
distillation_config.use_optimized_dataset = use_optimized_dataset
|
1733 |
distillation_config.custom_dataset_path = dataset_path
|
|
|
1734 |
if use_optimized_dataset and train:
|
1735 |
dataset_source = dataset_path or "code_model2vec/dataset"
|
1736 |
logger.info(f"π― Using optimized dataset from: {dataset_source}")
|
|
|
|
|
1737 |
|
1738 |
logger.info(f"π Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
|
1739 |
logger.info(f"βοΈ Execution: {'Beam' if use_beam else 'Local'}")
|
@@ -2200,7 +2004,7 @@ def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, s
|
|
2200 |
if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists():
|
2201 |
logger.info("π Custom dataset not found - creating optimized dataset...")
|
2202 |
create_optimized_dataset(
|
2203 |
-
max_samples_per_lang=
|
2204 |
output_dir=custom_dataset_dir,
|
2205 |
create_multiple_formats=False, # Use simple format for tokenlearn
|
2206 |
)
|
@@ -2230,14 +2034,13 @@ def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, s
|
|
2230 |
return str(train_json_path), None, "text"
|
2231 |
|
2232 |
|
2233 |
-
def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str, str]:
|
2234 |
-
"""Prepare original
|
2235 |
-
logger.info("π Using
|
2236 |
-
|
2237 |
return (
|
2238 |
-
str(distillation_config.tokenlearn_dataset), # "
|
2239 |
-
str(distillation_config.tokenlearn_dataset_name), # "
|
2240 |
-
str(distillation_config.tokenlearn_text_key), # "
|
2241 |
)
|
2242 |
|
2243 |
|
|
|
28 |
from pathlib import Path
|
29 |
from typing import Annotated, Any
|
30 |
|
|
|
31 |
import torch
|
32 |
import typer
|
33 |
from beam import function
|
|
|
409 |
|
410 |
|
411 |
def load_optimized_dataset(
|
412 |
+
max_samples: int | None = None,
|
413 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
414 |
dataset_path: str | None = None,
|
415 |
) -> list[str]:
|
|
|
423 |
|
424 |
dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR
|
425 |
|
426 |
+
# Use configuration default if not specified
|
427 |
+
if max_samples is None:
|
428 |
+
max_samples = distillation_config.tokenlearn_max_samples
|
429 |
+
|
430 |
logger.info(f"π― Loading optimized dataset from {dataset_dir}")
|
431 |
logger.info(f"π Target samples: {max_samples}")
|
432 |
|
|
|
465 |
|
466 |
|
467 |
def load_codesearchnet_dataset(
|
468 |
+
max_samples: int | None = None,
|
469 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
470 |
) -> list[str]:
|
471 |
"""Load and format the CodeSearchNet dataset for token frequency computation."""
|
472 |
from datasets import load_dataset
|
473 |
|
474 |
+
# Use configuration default if not specified
|
475 |
+
if max_samples is None:
|
476 |
+
max_samples = distillation_config.tokenlearn_max_samples
|
477 |
+
|
478 |
logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
|
479 |
logger.info(f"Limiting to {max_samples} samples for training efficiency")
|
480 |
logger.info(f"Languages: {', '.join(languages_config.all)}")
|
|
|
739 |
return teacher_embeddings
|
740 |
|
741 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
742 |
def tokenlearn_training(
|
743 |
student_model: Any,
|
744 |
teacher_model: SentenceTransformer,
|
745 |
checkpoint_manager: BeamCheckpointManager | None = None, # noqa: ARG001
|
|
|
746 |
) -> Any:
|
747 |
"""
|
748 |
Perform tokenlearn training following the official POTION approach.
|
|
|
751 |
1. Model2Vec distillation (already done - student_model)
|
752 |
2. Sentence transformer inference (create features)
|
753 |
3. Tokenlearn training
|
|
|
754 |
"""
|
755 |
from pathlib import Path
|
756 |
|
|
|
867 |
featurization_complete_marker.touch()
|
868 |
logger.info(f"πΎ Created featurization checkpoint: {featurization_complete_marker}")
|
869 |
|
|
|
|
|
|
|
|
|
870 |
except Exception as e:
|
871 |
logger.exception("π₯ Tokenlearn featurization failed")
|
872 |
logger.exception("π₯ Tokenlearn featurization is required for training - cannot proceed")
|
|
|
1011 |
logger.info("π Loading model from tokenlearn training...")
|
1012 |
trained_model = StaticModel.from_pretrained(str(trained_model_path))
|
1013 |
|
1014 |
+
# Return the trained model directly
|
1015 |
+
logger.info("β
Tokenlearn training pipeline completed successfully")
|
1016 |
+
return trained_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1017 |
|
1018 |
except ValueError as e:
|
1019 |
if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
|
|
|
1176 |
base_model,
|
1177 |
teacher_st_model,
|
1178 |
checkpoint_mgr,
|
|
|
1179 |
)
|
1180 |
|
1181 |
# Save final model
|
|
|
1515 |
clear_checkpoints: Annotated[
|
1516 |
bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
|
1517 |
] = False,
|
|
|
|
|
|
|
1518 |
use_optimized_dataset: Annotated[
|
1519 |
bool,
|
1520 |
typer.Option(
|
|
|
1529 |
"""Unified distillation command with optional training."""
|
1530 |
logger.info("π Starting unified Model2Vec distillation workflow")
|
1531 |
|
|
|
|
|
|
|
|
|
|
|
1532 |
# Set dataset configuration
|
1533 |
distillation_config.use_optimized_dataset = use_optimized_dataset
|
1534 |
distillation_config.custom_dataset_path = dataset_path
|
1535 |
+
|
1536 |
if use_optimized_dataset and train:
|
1537 |
dataset_source = dataset_path or "code_model2vec/dataset"
|
1538 |
logger.info(f"π― Using optimized dataset from: {dataset_source}")
|
1539 |
+
elif train:
|
1540 |
+
logger.info("π― Using C4 dataset for training (following POTION approach)")
|
1541 |
|
1542 |
logger.info(f"π Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
|
1543 |
logger.info(f"βοΈ Execution: {'Beam' if use_beam else 'Local'}")
|
|
|
2004 |
if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists():
|
2005 |
logger.info("π Custom dataset not found - creating optimized dataset...")
|
2006 |
create_optimized_dataset(
|
2007 |
+
max_samples_per_lang=distillation_config.tokenlearn_max_samples // 6, # Divide by number of languages
|
2008 |
output_dir=custom_dataset_dir,
|
2009 |
create_multiple_formats=False, # Use simple format for tokenlearn
|
2010 |
)
|
|
|
2034 |
return str(train_json_path), None, "text"
|
2035 |
|
2036 |
|
2037 |
+
def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str | None, str]:
|
2038 |
+
"""Prepare original dataset for tokenlearn featurization (uses C4 by default following POTION approach)."""
|
2039 |
+
logger.info("π Using C4 dataset for tokenlearn (following POTION approach)...")
|
|
|
2040 |
return (
|
2041 |
+
str(distillation_config.tokenlearn_dataset), # "allenai/c4"
|
2042 |
+
str(distillation_config.tokenlearn_dataset_name), # "en"
|
2043 |
+
str(distillation_config.tokenlearn_text_key), # "text"
|
2044 |
)
|
2045 |
|
2046 |
|