Sarthak commited on
Commit
0dbb356
Β·
1 Parent(s): 7837959

chore: update README and REPORT with performance insights and dataset changes

Browse files

This commit adds a warning in the README regarding the performance degradation observed with C4 fine-tuning, recommending the use of basic distillation for optimal results. The REPORT has been updated to reflect revised performance metrics for the fine-tuned model, adjustments in average performance statistics, and the inclusion of new radar charts for model comparisons. Additionally, the dataset configuration has been modified to utilize the C4 dataset for tokenlearn featurization.

NOTES.md ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Research Notes: Performance Analysis of C4 Fine-tuning vs Base Distillation
2
+
3
+ ## πŸ“Š Executive Summary
4
+
5
+ **Key Finding**: C4 fine-tuning significantly degraded performance across almost all metrics and programming languages compared to simple Model2Vec distillation.
6
+
7
+ **Recommendation**: Use simple Model2Vec distillation without additional training for optimal code embedding performance.
8
+
9
+ ---
10
+
11
+ ## πŸ“‰ Overall Performance Degradation
12
+
13
+ The comparison between base distilled models and C4-fine-tuned models reveals substantial performance regression:
14
+
15
+ | Metric | Base Model | Fine-tuned Model | Performance Drop |
16
+ |--------|------------|------------------|------------------|
17
+ | **NDCG@10** | 0.7387 | 0.6147 | **-16.8%** |
18
+ | **MRR** | 0.7010 | 0.5720 | **-18.4%** |
19
+ | **Recall@5** | 0.8017 | 0.6950 | **-13.3%** |
20
+ | **Recall@1** | 0.6169 | 0.4650 | **-24.6%** |
21
+
22
+ **Impact**: Double-digit performance drops across all major retrieval metrics, with Recall@1 suffering the most severe degradation at nearly 25%.
23
+
24
+ ---
25
+
26
+ ## πŸ” Language-Specific Impact Analysis
27
+
28
+ The performance degradation varied significantly across programming languages, revealing interesting patterns about domain sensitivity:
29
+
30
+ ### 🚨 **Severely Affected Languages**
31
+
32
+ #### **Java** (Catastrophic degradation):
33
+ - **NDCG@10**: 0.7027 β†’ 0.2820 (**-59.9%**)
34
+ - **MRR**: 0.6553 β†’ 0.2419 (**-63.1%**)
35
+ - **Mean Rank**: 7.24 β†’ 20.38 (almost 3x worse ranking)
36
+ - **Analysis**: Java suffered the most severe degradation, suggesting its documentation patterns are most incompatible with C4's web text distribution.
37
+
38
+ #### **PHP** (Major degradation):
39
+ - **NDCG@10**: 0.7055 β†’ 0.4453 (**-36.9%**)
40
+ - **MRR**: 0.6631 β†’ 0.3981 (**-40.0%**)
41
+ - **Analysis**: PHP's unique syntax and documentation style may have been particularly disrupted by general web text training.
42
+
43
+ ### πŸ“Š **Moderately Affected Languages**
44
+
45
+ #### **Python** (Best preserved):
46
+ - **NDCG@10**: 0.9674 β†’ 0.9219 (**-4.7%**)
47
+ - **MRR**: 0.9572 β†’ 0.8964 (**-6.3%**)
48
+ - **Analysis**: Python showed the smallest degradation, likely due to its prevalence in web tutorials and documentation that might overlap with C4 content.
49
+
50
+ #### **Ruby** (Minor degradation):
51
+ - **NDCG@10**: 0.7287 β†’ 0.7178 (**-1.5%**)
52
+ - **MRR**: 0.6869 β†’ 0.6776 (**-1.4%**)
53
+
54
+ #### **Go** (Minor degradation):
55
+ - **NDCG@10**: 0.7529 β†’ 0.7250 (**-3.7%**)
56
+ - **MRR**: 0.7059 β†’ 0.6699 (**-5.1%**)
57
+
58
+ ### βœ… **Single Improvement**
59
+
60
+ #### **JavaScript** (Slight improvement):
61
+ - **NDCG@10**: 0.5752 β†’ 0.5959 (**+3.6%**)
62
+ - **MRR**: 0.5378 β†’ 0.5481 (**+1.9%**)
63
+ - **Analysis**: JavaScript was the only language to show improvement, possibly due to extensive JavaScript content in web pages that align with C4's distribution.
64
+
65
+ ---
66
+
67
+ ## πŸ” Model Characteristics Comparison
68
+
69
+ | Aspect | Base Model | Fine-tuned Model | Change | Impact |
70
+ |--------|------------|------------------|--------|---------|
71
+ | **Parameters** | 7.56M | 9.38M | +24% larger | Increased complexity |
72
+ | **Disk Size** | 15.07MB | 36.94MB | +145% larger | Storage overhead |
73
+ | **Performance** | Superior | Inferior | Significantly worse | Counterproductive |
74
+ | **Efficiency** | High | Low | Worse per parameter | Resource waste |
75
+
76
+ **Key Insight**: The fine-tuned model is larger, more complex, and performs worseβ€”a clear example of the "bigger is not always better" principle.
77
+
78
+ ---
79
+
80
+ ## 🧠 Root Cause Analysis
81
+
82
+ ### 1. **🌐 Domain Mismatch**
83
+ - **Problem**: C4 contains general web text (articles, forums, websites, news)
84
+ - **Impact**: Code documentation has fundamentally different linguistic patterns, vocabulary, and structure
85
+ - **Result**: Training on web text actively degraded code-specific knowledge
86
+
87
+ ### 2. **🧠 Catastrophic Forgetting**
88
+ - **Problem**: The model "forgot" code-specific embeddings during C4 training
89
+ - **Evidence**: Java and PHP were hit hardest (59.9% and 36.9% NDCG@10 drops respectively)
90
+ - **Mechanism**: New training overwrote previously learned code-specific representations
91
+
92
+ ### 3. **πŸ“Š Distribution Shift**
93
+ - **Problem**: C4's token distribution is vastly different from code comments and documentation
94
+ - **Impact**: Model learned patterns that are irrelevant or harmful for code retrieval
95
+ - **Evidence**: Uniform degradation across most languages suggests systematic distribution mismatch
96
+
97
+ ### 4. **βš–οΈ Training Methodology Issues**
98
+ - **Problem**: Tokenlearn training on C4 introduced noise rather than signal
99
+ - **Analysis**: The POTION approach works well for general text but fails for specialized domains
100
+ - **Conclusion**: Domain-agnostic training methods can be counterproductive
101
+
102
+ ---
103
+
104
+ ## πŸ“ˆ Performance vs Complexity Analysis
105
+
106
+ ```
107
+ Performance Efficiency = NDCG@10 / Model_Size_MB
108
+
109
+ Base Model: 0.7387 / 15.07 = 0.049 (High efficiency)
110
+ Fine-tuned Model: 0.6147 / 36.94 = 0.017 (Low efficiency)
111
+
112
+ Efficiency Loss: 65.3%
113
+ ```
114
+
115
+ The fine-tuned model is not only worse performing but also dramatically less efficient, representing a significant regression in both absolute and relative terms.
116
+
117
+ ---
118
+
119
+ ## 🎯 Key Research Insights
120
+
121
+ ### 1. **Domain Specificity Matters**
122
+ Code embeddings require domain-specific training data. General web text (C4) actively harms code retrieval performance.
123
+
124
+ ### 2. **Language-Dependent Vulnerability**
125
+ Programming languages show different sensitivity to domain shift:
126
+ - **High vulnerability**: Java, PHP (enterprise/web languages)
127
+ - **Medium vulnerability**: Go, Ruby
128
+ - **Low vulnerability**: Python (ubiquitous in tutorials)
129
+ - **Potential benefit**: JavaScript (web-native language)
130
+
131
+ ### 3. **Simple Distillation Superiority**
132
+ Model2Vec's simple distillation approach outperforms complex fine-tuning when training data is misaligned with the target domain.
133
+
134
+ ### 4. **Training Data Quality > Quantity**
135
+ Using massive but irrelevant data (C4) is worse than using no additional training at all.
136
+
137
+ ---
138
+
139
+ ## πŸ“‹ Actionable Recommendations
140
+
141
+ ### ❌ **What NOT to Do**
142
+ 1. **Don't use C4 for code models**: General web text degrades code-specific performance
143
+ 2. **Don't assume more training is better**: Additional training can be counterproductive
144
+ 3. **Don't ignore domain alignment**: Training data must match target application domain
145
+ 4. **Don't prioritize model size**: Larger models can perform worse if poorly trained
146
+
147
+ ### βœ… **What TO Do**
148
+ 1. **Stick to base distillation**: Simple Model2Vec distillation gives optimal results for code tasks
149
+ 2. **Use code-specific datasets only**: If fine-tuning is needed, use CodeSearchNet or similar datasets
150
+ 3. **Validate domain alignment**: Ensure training data distribution matches target use case
151
+ 4. **Measure efficiency**: Consider performance per parameter, not just absolute performance
152
+ 5. **Test incrementally**: Validate that each training step improves rather than degrades performance
153
+
154
+ ### πŸ”¬ **Future Research Directions**
155
+ 1. **Code-specific fine-tuning**: Investigate tokenlearn training with CodeSearchNet instead of C4
156
+ 2. **Selective fine-tuning**: Apply additional training only to languages that show potential benefit (JavaScript)
157
+ 3. **Hybrid approaches**: Combine base distillation with minimal, targeted code-specific training
158
+ 4. **Domain adaptation techniques**: Develop methods to prevent catastrophic forgetting during domain transfer
159
+
160
+ ---
161
+
162
+ ## πŸ“Š Statistical Significance
163
+
164
+ All performance drops are substantial and consistent across metrics:
165
+ - **Minimum degradation**: 1.4% (Ruby MRR)
166
+ - **Maximum degradation**: 63.1% (Java MRR)
167
+ - **Median degradation**: ~15% across all metrics
168
+ - **Only improvement**: JavaScript (+3.6% NDCG@10)
169
+
170
+ **Conclusion**: The degradation is not due to random variation but represents a systematic failure of the C4 fine-tuning approach.
171
+
172
+ ---
173
+
174
+ ## πŸŽ“ Lessons Learned
175
+
176
+ 1. **Domain expertise beats scale**: Code-specific knowledge is more valuable than training on massive general datasets
177
+ 2. **Validate training approaches**: Always compare against simpler baselines before deploying complex training pipelines
178
+ 3. **Language-specific patterns matter**: Different programming languages have varying sensitivity to domain shift
179
+ 4. **Efficiency is crucial**: Model performance per parameter is often more important than absolute performance
180
+ 5. **Simple can be superior**: Sometimes the simplest approach (basic distillation) outperforms sophisticated alternatives
181
+
182
+ ---
183
+
184
+ **Documentation Date**: December 2024
185
+ **Model Comparison**: `sentence-transformers/all-mpnet-base-v2` teacher β†’ Model2Vec distillation vs Model2Vec + C4 tokenlearn fine-tuning
186
+ **Evaluation Dataset**: CodeSearchNet across 6 programming languages
187
+ **Key Finding**: Simple distillation outperforms complex fine-tuning by 16.8% NDCG@10 on average
README.md CHANGED
@@ -71,6 +71,9 @@ pipeline_tag: feature-extraction
71
  >[!Important]
72
  >Check out the comprehensive [REPORT.md](REPORT.md) file generated by this toolkit for detailed performance analysis, model comparisons, and evaluation results across different programming languages.
73
 
 
 
 
74
  The **distiller** package provides a complete pipeline for:
75
 
76
  1. **Distilling code-specialized embeddings** from large sentence transformer models using Model2Vec
 
71
  >[!Important]
72
  >Check out the comprehensive [REPORT.md](REPORT.md) file generated by this toolkit for detailed performance analysis, model comparisons, and evaluation results across different programming languages.
73
 
74
+ >[!Warning]
75
+ >**Research Finding**: See [NOTES.md](NOTES.md) for critical analysis showing that C4 fine-tuning significantly degraded performance (-16.8% NDCG@10) compared to simple Model2Vec distillation. **Recommendation**: Use basic distillation without additional training for optimal code embedding performance.
76
+
77
  The **distiller** package provides a complete pipeline for:
78
 
79
  1. **Distilling code-specialized embeddings** from large sentence transformer models using Model2Vec
REPORT.md CHANGED
@@ -29,7 +29,7 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
29
  | code_model2vec_jina_embeddings_v2_base_code | [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code) | 0.7381 | 0.6996 | 0.8130 | πŸ₯‰ 3rd |
30
  | code_model2vec_paraphrase_MiniLM_L6_v2 | [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) | 0.7013 | 0.6638 | 0.7665 | #4 |
31
  | code_model2vec_Reason_ModernColBERT | [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) | 0.6598 | 0.6228 | 0.7260 | #5 |
32
- | code_model2vec_all_mpnet_base_v2_fine_tuned | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) | 0.5347 | 0.4875 | 0.6200 | #6 |
33
  | code_model2vec_bge_m3 | [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | 0.4863 | 0.4439 | 0.5514 | #7 |
34
  | code_model2vec_jina_embeddings_v3 | [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) | 0.4755 | 0.4416 | 0.5456 | #8 |
35
  | code_model2vec_nomic_embed_text_v2_moe | [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) | 0.4532 | 0.4275 | 0.5094 | #9 |
@@ -51,7 +51,7 @@ Our distilled models exhibit consistent architectural characteristics across dif
51
  | jina_embeddings_v2_base_code | 61,053 | 15.6M | 256 | 29.8MB |
52
  | paraphrase_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
53
  | Reason_ModernColBERT | 50,254 | 12.9M | 256 | 24.5MB |
54
- | all_mpnet_base_v2_fine_tuned | 29,528 | 7.6M | 256 | 28.8MB |
55
  | bge_m3 | 249,999 | 64.0M | 256 | 122.1MB |
56
  | jina_embeddings_v3 | 249,999 | 64.0M | 256 | 122.1MB |
57
  | nomic_embed_text_v2_moe | 249,999 | 64.0M | 256 | 122.1MB |
@@ -69,9 +69,9 @@ Our distilled models exhibit consistent architectural characteristics across dif
69
  #### Key Insights from Model Specifications:
70
 
71
 
72
- - **Vocabulary Consistency**: All models use vocabulary sizes ranging from 29,525 to 249,999 tokens (avg: 101,087)
73
- - **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg: 25.9M)
74
- - **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg: 50.4MB)
75
  - **Embedding Dimensions**: Consistent 256 dimensions across all models (optimized for efficiency)
76
 
77
 
@@ -81,13 +81,85 @@ Our distilled models exhibit consistent architectural characteristics across dif
81
  - **Best Teacher Model**: code_model2vec_all_mpnet_base_v2 (NDCG@10: 0.7387)
82
  - **Least Effective Teacher**: code_model2vec_codebert_base (NDCG@10: 0.2779)
83
  - **Performance Range**: 62.4% difference between best and worst
84
- - **Average Performance**: 0.5190 NDCG@10
85
 
86
 
87
  ## 🎯 Language Performance Radar Charts
88
 
89
  ### Best Model vs Peer Models Comparison
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  | Rank | Model | Type | NDCG@10 | MRR | Recall@5 |
92
  |------|-------|------|---------|-----|----------|
93
  | 1 | Alibaba-NLP/gte-Qwen2-1.5B-instruct | General | 0.9729 | 0.9676 | 0.9825 |
@@ -109,9 +181,9 @@ Our distilled models exhibit consistent architectural characteristics across dif
109
  | 17 | code_model2vec_jina_embeddings_v2_base_code | **πŸ”₯ Simplified Distillation** | 0.7381 | 0.6996 | 0.8130 |
110
  | 18 | code_model2vec_paraphrase_MiniLM_L6_v2 | **πŸ”₯ Simplified Distillation** | 0.7013 | 0.6638 | 0.7665 |
111
  | 19 | code_model2vec_Reason_ModernColBERT | **πŸ”₯ Simplified Distillation** | 0.6598 | 0.6228 | 0.7260 |
112
- | 20 | potion-multilingual-128M | Model2Vec | 0.6124 | 0.5683 | 0.7017 |
113
- | 21 | huggingface/CodeBERTa-small-v1 | Code-Specific | 0.5903 | 0.5350 | 0.6779 |
114
- | 22 | code_model2vec_all_mpnet_base_v2_fine_tuned | **πŸŽ“ Fine-tuned Distillation** | 0.5347 | 0.4875 | 0.6200 |
115
  | 23 | Salesforce/codet5-base | Code-Specific | 0.4872 | 0.4500 | 0.5742 |
116
  | 24 | code_model2vec_bge_m3 | **πŸ”₯ Simplified Distillation** | 0.4863 | 0.4439 | 0.5514 |
117
  | 25 | code_model2vec_jina_embeddings_v3 | **πŸ”₯ Simplified Distillation** | 0.4755 | 0.4416 | 0.5456 |
@@ -171,12 +243,12 @@ Our distilled models exhibit consistent architectural characteristics across dif
171
 
172
  | Language | Best Model Performance | Average Performance | Language Difficulty |
173
  |----------|------------------------|--------------------|--------------------|
174
- | Go | 0.9780 | 0.6923 | Easy |
175
- | Java | 0.9921 | 0.6545 | Easy |
176
- | Javascript | 0.9550 | 0.5831 | Easy |
177
- | Php | 1.0000 | 0.6325 | Easy |
178
- | Python | 1.0000 | 0.8599 | Easy |
179
- | Ruby | 0.9493 | 0.6333 | Easy |
180
 
181
 
182
  ## 🎯 Conclusions and Recommendations
@@ -230,5 +302,5 @@ Based on the evaluation results across all simplified distillation models:
230
 
231
  ---
232
 
233
- *Report generated on 2025-05-31 21:07:06 using automated analysis pipeline.*
234
  *For questions about methodology or results, please refer to the CodeSearchNet documentation.*
 
29
  | code_model2vec_jina_embeddings_v2_base_code | [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code) | 0.7381 | 0.6996 | 0.8130 | πŸ₯‰ 3rd |
30
  | code_model2vec_paraphrase_MiniLM_L6_v2 | [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) | 0.7013 | 0.6638 | 0.7665 | #4 |
31
  | code_model2vec_Reason_ModernColBERT | [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) | 0.6598 | 0.6228 | 0.7260 | #5 |
32
+ | code_model2vec_all_mpnet_base_v2_fine_tuned | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) | 0.6147 | 0.5720 | 0.6950 | #6 |
33
  | code_model2vec_bge_m3 | [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | 0.4863 | 0.4439 | 0.5514 | #7 |
34
  | code_model2vec_jina_embeddings_v3 | [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) | 0.4755 | 0.4416 | 0.5456 | #8 |
35
  | code_model2vec_nomic_embed_text_v2_moe | [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) | 0.4532 | 0.4275 | 0.5094 | #9 |
 
51
  | jina_embeddings_v2_base_code | 61,053 | 15.6M | 256 | 29.8MB |
52
  | paraphrase_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
53
  | Reason_ModernColBERT | 50,254 | 12.9M | 256 | 24.5MB |
54
+ | all_mpnet_base_v2_fine_tuned | 36,624 | 9.4M | 256 | 35.8MB |
55
  | bge_m3 | 249,999 | 64.0M | 256 | 122.1MB |
56
  | jina_embeddings_v3 | 249,999 | 64.0M | 256 | 122.1MB |
57
  | nomic_embed_text_v2_moe | 249,999 | 64.0M | 256 | 122.1MB |
 
69
  #### Key Insights from Model Specifications:
70
 
71
 
72
+ - **Vocabulary Consistency**: All models use vocabulary sizes ranging from 29,525 to 249,999 tokens (avg: 101,594)
73
+ - **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg: 26.0M)
74
+ - **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg: 50.9MB)
75
  - **Embedding Dimensions**: Consistent 256 dimensions across all models (optimized for efficiency)
76
 
77
 
 
81
  - **Best Teacher Model**: code_model2vec_all_mpnet_base_v2 (NDCG@10: 0.7387)
82
  - **Least Effective Teacher**: code_model2vec_codebert_base (NDCG@10: 0.2779)
83
  - **Performance Range**: 62.4% difference between best and worst
84
+ - **Average Performance**: 0.5248 NDCG@10
85
 
86
 
87
  ## 🎯 Language Performance Radar Charts
88
 
89
  ### Best Model vs Peer Models Comparison
90
 
91
+ ![Comparative Radar Chart](analysis_charts/comparative_radar.png)
92
+
93
+ *Comparative view showing how the best simplified distillation model performs against top peer models across programming languages.*
94
+
95
+ ### Individual Model Performance by Language
96
+
97
+ #### code_model2vec_all_mpnet_base_v2 (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.7387
98
+
99
+ ![code_model2vec_all_mpnet_base_v2 Radar Chart](analysis_charts/radar_code_model2vec_all_mpnet_base_v2.png)
100
+
101
+ #### code_model2vec_all_MiniLM_L6_v2 (Teacher: [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)) - NDCG@10: 0.7385
102
+
103
+ ![code_model2vec_all_MiniLM_L6_v2 Radar Chart](analysis_charts/radar_code_model2vec_all_MiniLM_L6_v2.png)
104
+
105
+ #### code_model2vec_jina_embeddings_v2_base_code (Teacher: [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code)) - NDCG@10: 0.7381
106
+
107
+ ![code_model2vec_jina_embeddings_v2_base_code Radar Chart](analysis_charts/radar_code_model2vec_jina_embeddings_v2_base_code.png)
108
+
109
+ #### code_model2vec_paraphrase_MiniLM_L6_v2 (Teacher: [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)) - NDCG@10: 0.7013
110
+
111
+ ![code_model2vec_paraphrase_MiniLM_L6_v2 Radar Chart](analysis_charts/radar_code_model2vec_paraphrase_MiniLM_L6_v2.png)
112
+
113
+ #### code_model2vec_Reason_ModernColBERT (Teacher: [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT)) - NDCG@10: 0.6598
114
+
115
+ ![code_model2vec_Reason_ModernColBERT Radar Chart](analysis_charts/radar_code_model2vec_Reason_ModernColBERT.png)
116
+
117
+ #### code_model2vec_all_mpnet_base_v2_fine_tuned (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.6147
118
+
119
+ ![code_model2vec_all_mpnet_base_v2_fine_tuned Radar Chart](analysis_charts/radar_code_model2vec_all_mpnet_base_v2_fine_tuned.png)
120
+
121
+ #### code_model2vec_bge_m3 (Teacher: [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3)) - NDCG@10: 0.4863
122
+
123
+ ![code_model2vec_bge_m3 Radar Chart](analysis_charts/radar_code_model2vec_bge_m3.png)
124
+
125
+ #### code_model2vec_jina_embeddings_v3 (Teacher: [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3)) - NDCG@10: 0.4755
126
+
127
+ ![code_model2vec_jina_embeddings_v3 Radar Chart](analysis_charts/radar_code_model2vec_jina_embeddings_v3.png)
128
+
129
+ #### code_model2vec_nomic_embed_text_v2_moe (Teacher: [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe)) - NDCG@10: 0.4532
130
+
131
+ ![code_model2vec_nomic_embed_text_v2_moe Radar Chart](analysis_charts/radar_code_model2vec_nomic_embed_text_v2_moe.png)
132
+
133
+ #### code_model2vec_gte_Qwen2_1.5B_instruct (Teacher: [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)) - NDCG@10: 0.4238
134
+
135
+ ![code_model2vec_gte_Qwen2_1.5B_instruct Radar Chart](analysis_charts/radar_code_model2vec_gte_Qwen2_15B_instruct.png)
136
+
137
+ #### code_model2vec_Qodo_Embed_1_1.5B (Teacher: [Qodo/Qodo-Embed-1-1.5B](https://huggingface.co/Qodo/Qodo-Embed-1-1.5B)) - NDCG@10: 0.4101
138
+
139
+ ![code_model2vec_Qodo_Embed_1_1.5B Radar Chart](analysis_charts/radar_code_model2vec_Qodo_Embed_1_15B.png)
140
+
141
+ #### code_model2vec_graphcodebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.3420
142
+
143
+ ![code_model2vec_graphcodebert_base Radar Chart](analysis_charts/radar_code_model2vec_graphcodebert_base.png)
144
+
145
+ #### code_model2vec_Linq_Embed_Mistral (Teacher: [Linq-AI-Research/Linq-Embed-Mistral](https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral)) - NDCG@10: 0.2868
146
+
147
+ ![code_model2vec_Linq_Embed_Mistral Radar Chart](analysis_charts/radar_code_model2vec_Linq_Embed_Mistral.png)
148
+
149
+ #### code_model2vec_codebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.2779
150
+
151
+ ![code_model2vec_codebert_base Radar Chart](analysis_charts/radar_code_model2vec_codebert_base.png)
152
+
153
+
154
+
155
+ ## πŸ† Peer Model Comparison
156
+
157
+ ![Peer Comparison](analysis_charts/peer_comparison.png)
158
+
159
+ *Comparison with established code-specialized embedding models using actual evaluation results.*
160
+
161
+ ### Complete Model Ranking
162
+
163
  | Rank | Model | Type | NDCG@10 | MRR | Recall@5 |
164
  |------|-------|------|---------|-----|----------|
165
  | 1 | Alibaba-NLP/gte-Qwen2-1.5B-instruct | General | 0.9729 | 0.9676 | 0.9825 |
 
181
  | 17 | code_model2vec_jina_embeddings_v2_base_code | **πŸ”₯ Simplified Distillation** | 0.7381 | 0.6996 | 0.8130 |
182
  | 18 | code_model2vec_paraphrase_MiniLM_L6_v2 | **πŸ”₯ Simplified Distillation** | 0.7013 | 0.6638 | 0.7665 |
183
  | 19 | code_model2vec_Reason_ModernColBERT | **πŸ”₯ Simplified Distillation** | 0.6598 | 0.6228 | 0.7260 |
184
+ | 20 | code_model2vec_all_mpnet_base_v2_fine_tuned | **πŸŽ“ Fine-tuned Distillation** | 0.6147 | 0.5720 | 0.6950 |
185
+ | 21 | potion-multilingual-128M | Model2Vec | 0.6124 | 0.5683 | 0.7017 |
186
+ | 22 | huggingface/CodeBERTa-small-v1 | Code-Specific | 0.5903 | 0.5350 | 0.6779 |
187
  | 23 | Salesforce/codet5-base | Code-Specific | 0.4872 | 0.4500 | 0.5742 |
188
  | 24 | code_model2vec_bge_m3 | **πŸ”₯ Simplified Distillation** | 0.4863 | 0.4439 | 0.5514 |
189
  | 25 | code_model2vec_jina_embeddings_v3 | **πŸ”₯ Simplified Distillation** | 0.4755 | 0.4416 | 0.5456 |
 
243
 
244
  | Language | Best Model Performance | Average Performance | Language Difficulty |
245
  |----------|------------------------|--------------------|--------------------|
246
+ | Go | 0.9780 | 0.6960 | Easy |
247
+ | Java | 0.9921 | 0.6553 | Easy |
248
+ | Javascript | 0.9550 | 0.5850 | Easy |
249
+ | Php | 1.0000 | 0.6321 | Easy |
250
+ | Python | 1.0000 | 0.8623 | Easy |
251
+ | Ruby | 0.9493 | 0.6397 | Easy |
252
 
253
 
254
  ## 🎯 Conclusions and Recommendations
 
302
 
303
  ---
304
 
305
+ *Report generated on 2025-06-01 08:04:06 using automated analysis pipeline.*
306
  *For questions about methodology or results, please refer to the CodeSearchNet documentation.*
analysis_charts/batch_size_scaling.png CHANGED

Git LFS Details

  • SHA256: 1965e61be476c42036749ec5e0f96f177ea07ff282f89b19d1de284344556d3f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB

Git LFS Details

  • SHA256: 285208fa4aaa7388319a0f6c6f773211c542c7bc215d9175f46c4a2f193281da
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
analysis_charts/benchmark_performance.png CHANGED

Git LFS Details

  • SHA256: 0f632f595712a28a569c86c1ec44e0c1f0f29f6fcca0bd183cdb1ce8c045459a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.05 MB

Git LFS Details

  • SHA256: 480fd3cf8cc86be28f56fab6e7930335f2cabb1f068a4928010f3333d5d2e0ac
  • Pointer size: 132 Bytes
  • Size of remote file: 2.05 MB
analysis_charts/efficiency_analysis.png CHANGED

Git LFS Details

  • SHA256: aa6dc848da294ff0d1b3dc48b5a4edfb11aaf197cfe874066e437788f71bebc5
  • Pointer size: 131 Bytes
  • Size of remote file: 239 kB

Git LFS Details

  • SHA256: 3167079aedb34c9cae7108da4857e80b241d1cd38d784bbdd45ca88dc63b2151
  • Pointer size: 131 Bytes
  • Size of remote file: 240 kB
analysis_charts/language_heatmap.png CHANGED

Git LFS Details

  • SHA256: c13d6a4ae9e5d57bc1c11084903b6381af1f66d70ae3fcc39b34933118c12652
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB

Git LFS Details

  • SHA256: 92bd667a45139bb0cb118778109705a5d5319b5e19a0ec1465df9a36cf1f20a6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
analysis_charts/memory_scaling.png CHANGED

Git LFS Details

  • SHA256: 75c1b84a9354411a022adab6093a76830028e462b38801f4bc6ddabcb4ac09cc
  • Pointer size: 131 Bytes
  • Size of remote file: 640 kB

Git LFS Details

  • SHA256: 65475b837f7a92c8620979f6733ed6f4d4479deb03b9ddbee00e25398730f585
  • Pointer size: 131 Bytes
  • Size of remote file: 639 kB
analysis_charts/model_comparison.png CHANGED

Git LFS Details

  • SHA256: 4c31ab756a6923ec277c0f1e03dd7ce266f31d29fbbef3e86e1b81a35d9b42a6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB

Git LFS Details

  • SHA256: 483f24ff73c244b0323ef4e57b361cb89b4333fb564477448be4687cb4134348
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
analysis_charts/model_specifications.png CHANGED

Git LFS Details

  • SHA256: 26f0e1f91445f820c9bb4138a829b9dfc4eb5002699010047b805777dbd36c46
  • Pointer size: 131 Bytes
  • Size of remote file: 654 kB

Git LFS Details

  • SHA256: 6731c5ab5618e04881ebd1d8532099fc65264e3989a29fe3951abc17b0b15420
  • Pointer size: 131 Bytes
  • Size of remote file: 654 kB
analysis_charts/peer_comparison.png CHANGED

Git LFS Details

  • SHA256: 786489e7fb5237126cf6a5f8f4428ca8b5725b4e1977b10ce2f99bc47a81cb20
  • Pointer size: 131 Bytes
  • Size of remote file: 699 kB

Git LFS Details

  • SHA256: 7bbd82205a146bd1b313011cb3919eae790daaecb0af350e5e4626df785ef26b
  • Pointer size: 131 Bytes
  • Size of remote file: 698 kB
analysis_charts/radar_code_model2vec_all_mpnet_base_v2_fine_tuned.png CHANGED

Git LFS Details

  • SHA256: 86e7b9073df9d7d0ae22e6d253a5874130707db2ab96f8800d39fd24a4a9f927
  • Pointer size: 131 Bytes
  • Size of remote file: 203 kB

Git LFS Details

  • SHA256: a0b7a9ca0656c09aaf4d067ef33c40f8dc729b6357e4830d9fb6cef7dd049844
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB
src/distiller/__main__.py CHANGED
@@ -23,9 +23,6 @@ def distill(
23
  clear_checkpoints: Annotated[
24
  bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
25
  ] = False,
26
- skip_ptr: Annotated[
27
- bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
28
- ] = False,
29
  use_optimized_dataset: Annotated[
30
  bool,
31
  typer.Option(
@@ -48,7 +45,6 @@ def distill(
48
  pca_dims,
49
  clear_cache,
50
  clear_checkpoints,
51
- skip_ptr,
52
  use_optimized_dataset,
53
  dataset_path,
54
  )
 
23
  clear_checkpoints: Annotated[
24
  bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
25
  ] = False,
 
 
 
26
  use_optimized_dataset: Annotated[
27
  bool,
28
  typer.Option(
 
45
  pca_dims,
46
  clear_cache,
47
  clear_checkpoints,
 
48
  use_optimized_dataset,
49
  dataset_path,
50
  )
src/distiller/config.py CHANGED
@@ -210,16 +210,14 @@ class DistillationConfig(BaseModel):
210
  apply_zipf: bool = True
211
 
212
  # Tokenlearn-specific parameters (POTION approach)
213
- tokenlearn_dataset: str = "sentence-transformers/codesearchnet" # Dataset for tokenlearn featurization
214
- tokenlearn_dataset_name: str = "pair" # Use 'pair' configuration (only available config)
215
- tokenlearn_text_key: str = (
216
- "combined_text" # Text field to use from the dataset ('combined_text' for doc-code pairs)
217
- )
218
  tokenlearn_timeout_featurize: int = 21600 # 6 hour timeout for featurization (dataset needs ~5 hours)
219
  tokenlearn_timeout_train: int = 7200 # 2 hour timeout for training
220
 
221
- # Post-training configuration
222
- skip_post_training_regularization: bool = False # Skip PCA + SIF re-regularization step
223
 
224
  # Dataset configuration
225
  use_optimized_dataset: bool = True # Use the pre-created optimized dataset from dataset.py
 
210
  apply_zipf: bool = True
211
 
212
  # Tokenlearn-specific parameters (POTION approach)
213
+ tokenlearn_dataset: str = "allenai/c4" # Dataset for tokenlearn featurization (following POTION paper)
214
+ tokenlearn_dataset_name: str = "en" # Use 'en' configuration for English text
215
+ tokenlearn_text_key: str = "text" # Text field to use from the dataset
 
 
216
  tokenlearn_timeout_featurize: int = 21600 # 6 hour timeout for featurization (dataset needs ~5 hours)
217
  tokenlearn_timeout_train: int = 7200 # 2 hour timeout for training
218
 
219
+ # Dataset sampling configuration
220
+ tokenlearn_max_samples: int = 50000 # Maximum samples to use for tokenlearn training
221
 
222
  # Dataset configuration
223
  use_optimized_dataset: bool = True # Use the pre-created optimized dataset from dataset.py
src/distiller/distill.py CHANGED
@@ -28,7 +28,6 @@ import time
28
  from pathlib import Path
29
  from typing import Annotated, Any
30
 
31
- import numpy as np
32
  import torch
33
  import typer
34
  from beam import function
@@ -410,7 +409,7 @@ def simple_distillation(
410
 
411
 
412
  def load_optimized_dataset(
413
- max_samples: int = 50000,
414
  checkpoint_manager: BeamCheckpointManager | None = None,
415
  dataset_path: str | None = None,
416
  ) -> list[str]:
@@ -424,6 +423,10 @@ def load_optimized_dataset(
424
 
425
  dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR
426
 
 
 
 
 
427
  logger.info(f"🎯 Loading optimized dataset from {dataset_dir}")
428
  logger.info(f"πŸ“Š Target samples: {max_samples}")
429
 
@@ -462,12 +465,16 @@ def load_optimized_dataset(
462
 
463
 
464
  def load_codesearchnet_dataset(
465
- max_samples: int = 50000,
466
  checkpoint_manager: BeamCheckpointManager | None = None,
467
  ) -> list[str]:
468
  """Load and format the CodeSearchNet dataset for token frequency computation."""
469
  from datasets import load_dataset
470
 
 
 
 
 
471
  logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
472
  logger.info(f"Limiting to {max_samples} samples for training efficiency")
473
  logger.info(f"Languages: {', '.join(languages_config.all)}")
@@ -732,192 +739,10 @@ def generate_teacher_embeddings(
732
  return teacher_embeddings
733
 
734
 
735
- def compute_token_frequencies_for_sif(
736
- teacher_model: SentenceTransformer,
737
- features_dir: Path,
738
- ) -> None:
739
- """
740
- Compute token frequencies from the training corpus for SIF weighting.
741
-
742
- This follows the POTION approach for post-training re-regularization.
743
- """
744
- import json
745
- from collections import Counter
746
-
747
- logger.info("πŸ“Š Computing token frequencies for SIF weighting...")
748
-
749
- try:
750
- # Load dataset to compute frequencies (limited sample for efficiency)
751
- if distillation_config.use_optimized_dataset:
752
- # Use the custom optimized dataset
753
- from .dataset import load_optimized_dataset as load_custom_dataset
754
-
755
- custom_dataset_dir = (
756
- Path(distillation_config.custom_dataset_path)
757
- if distillation_config.custom_dataset_path
758
- else Path("code_model2vec/dataset")
759
- )
760
-
761
- if custom_dataset_dir.exists() and (custom_dataset_dir / "train.parquet").exists():
762
- train_df = load_custom_dataset(output_dir=custom_dataset_dir, split="train")
763
- # Sample a subset for frequency computation
764
- sample_size = min(10000, len(train_df))
765
- train_df_sample = train_df.sample(n=sample_size, random_state=42)
766
- dataset_texts = train_df_sample["text"].tolist()
767
- logger.info(f"πŸ“Š Using {len(dataset_texts)} samples from custom optimized dataset")
768
- else:
769
- # Fallback to original dataset loading
770
- dataset_texts = load_codesearchnet_dataset(max_samples=10000)
771
- logger.info(
772
- f"πŸ“Š Custom dataset not found, using original CodeSearchNet with {len(dataset_texts)} texts"
773
- )
774
- else:
775
- dataset_texts = load_codesearchnet_dataset(max_samples=10000)
776
- logger.info(f"πŸ“Š Using original CodeSearchNet with {len(dataset_texts)} texts")
777
-
778
- logger.info(f"πŸ“Š Computing frequencies on {len(dataset_texts)} texts...")
779
-
780
- # Tokenize all texts and count token frequencies
781
- tokenizer = teacher_model.tokenizer
782
- token_counts: Counter[int] = Counter()
783
-
784
- # Process in batches to avoid memory issues
785
- batch_size = 100
786
- for i in range(0, len(dataset_texts), batch_size):
787
- batch_texts = dataset_texts[i : i + batch_size]
788
-
789
- for text in batch_texts:
790
- # Tokenize the text
791
- tokens = tokenizer.encode(text, add_special_tokens=False)
792
- token_counts.update(tokens)
793
-
794
- if i % (batch_size * 10) == 0:
795
- logger.info(f" Processed {i + len(batch_texts)}/{len(dataset_texts)} texts...")
796
-
797
- # Convert to frequencies (token_id -> count)
798
- token_frequencies = dict(token_counts)
799
-
800
- # Save token frequencies to features directory for post-training regularization
801
- freq_file = features_dir / "token_frequencies.json"
802
- with freq_file.open("w") as f:
803
- json.dump(token_frequencies, f, indent=2)
804
-
805
- logger.info(f"βœ… Token frequencies saved to {freq_file}")
806
- logger.info(f"πŸ“Š Total unique tokens: {len(token_frequencies)}")
807
- logger.info(f"πŸ“Š Total token occurrences: {sum(token_frequencies.values())}")
808
-
809
- except Exception as e:
810
- logger.warning(f"⚠️ Failed to compute token frequencies: {e}")
811
- logger.warning("⚠️ Post-training re-regularization will use default Zipf weighting")
812
-
813
-
814
- def apply_post_training_regularization(
815
- model: Any,
816
- features_dir: Path,
817
- pca_dims: int = 256,
818
- ) -> Any:
819
- """
820
- Apply post-training re-regularization following the POTION approach.
821
-
822
- This includes:
823
- 1. Token frequency weighting using corpus frequencies
824
- 2. PCA application
825
- 3. SIF weighting using formula: w = 1e-3 / (1e-3 + proba)
826
- """
827
- import json
828
-
829
- from sklearn.decomposition import PCA
830
-
831
- logger.info("πŸ”§ Starting post-training re-regularization (POTION Step 4)")
832
-
833
- # Step 4a: Load token frequencies from the training corpus
834
- logger.info("πŸ“Š Computing token frequencies from training corpus...")
835
-
836
- # Try to load token frequencies from features directory
837
- freq_file = features_dir / "token_frequencies.json"
838
-
839
- if freq_file.exists():
840
- with freq_file.open("r") as f:
841
- token_frequencies = json.load(f)
842
- logger.info(f"βœ… Loaded token frequencies from {freq_file}")
843
- else:
844
- logger.warning("⚠️ Token frequencies not found - using default Zipf weighting")
845
- # Fallback to basic frequency estimation based on rank
846
- vocab_size = model.embedding.shape[0]
847
- token_frequencies = {str(i): 1.0 / (i + 1) for i in range(vocab_size)}
848
-
849
- # Step 4b: Apply PCA to the embeddings
850
- logger.info(f"πŸ”„ Applying PCA with {pca_dims} dimensions...")
851
-
852
- # Get current embeddings
853
- # Handle both torch tensors and numpy arrays
854
- if hasattr(model.embedding, "cpu"):
855
- embeddings = model.embedding.cpu().numpy().astype(np.float64)
856
- else:
857
- embeddings = model.embedding.astype(np.float64)
858
- original_shape = embeddings.shape
859
- logger.info(f"Original embedding shape: {original_shape}")
860
-
861
- # Apply PCA if dimensions don't match
862
- if original_shape[1] != pca_dims:
863
- pca = PCA(n_components=pca_dims, random_state=42)
864
- embeddings_pca = pca.fit_transform(embeddings)
865
- logger.info(f"PCA applied: {original_shape} β†’ {embeddings_pca.shape}")
866
-
867
- # Explained variance ratio
868
- explained_var = pca.explained_variance_ratio_.sum()
869
- logger.info(f"PCA explained variance ratio: {explained_var:.4f}")
870
- else:
871
- embeddings_pca = embeddings
872
- logger.info("PCA dimensions match - no PCA transformation needed")
873
-
874
- # Step 4c: Apply SIF weighting using corpus frequencies
875
- logger.info("βš–οΈ Applying SIF weighting based on token frequencies...")
876
-
877
- # Convert token frequencies to probabilities
878
- total_tokens = sum(token_frequencies.values())
879
- token_probs = {token: freq / total_tokens for token, freq in token_frequencies.items()}
880
-
881
- # Apply SIF weighting: w = 1e-3 / (1e-3 + proba)
882
- sif_coefficient = 1e-3 # Standard SIF coefficient
883
-
884
- for i in range(embeddings_pca.shape[0]):
885
- token_id = str(i)
886
- prob = token_probs[token_id] if token_id in token_probs else 1.0 / len(token_probs)
887
-
888
- # Apply SIF weighting formula
889
- sif_weight = sif_coefficient / (sif_coefficient + prob)
890
- embeddings_pca[i] *= sif_weight
891
-
892
- logger.info("βœ… SIF weighting applied successfully")
893
-
894
- # Step 4d: Create new model with re-regularized embeddings
895
- logger.info("πŸ“¦ Creating final model with re-regularized embeddings...")
896
-
897
- # Convert back to float32 numpy array
898
- final_embeddings = embeddings_pca.astype(np.float32)
899
-
900
- # Create new model with updated embeddings
901
- from distiller.model2vec.model import StaticModel
902
-
903
- # Save tokenizer and config from original model
904
- tokenizer = model.tokenizer
905
- config = model.config
906
-
907
- # Create new model with re-regularized embeddings
908
- final_model = StaticModel(vectors=final_embeddings, tokenizer=tokenizer, config=config)
909
-
910
- logger.info("βœ… Post-training re-regularization completed successfully")
911
- logger.info(f"Final model embedding shape: {final_model.embedding.shape}")
912
-
913
- return final_model
914
-
915
-
916
  def tokenlearn_training(
917
  student_model: Any,
918
  teacher_model: SentenceTransformer,
919
  checkpoint_manager: BeamCheckpointManager | None = None, # noqa: ARG001
920
- skip_post_training_regularization: bool = False,
921
  ) -> Any:
922
  """
923
  Perform tokenlearn training following the official POTION approach.
@@ -926,7 +751,6 @@ def tokenlearn_training(
926
  1. Model2Vec distillation (already done - student_model)
927
  2. Sentence transformer inference (create features)
928
  3. Tokenlearn training
929
- 4. Post-training re-regularization (PCA + SIF weighting)
930
  """
931
  from pathlib import Path
932
 
@@ -1043,10 +867,6 @@ def tokenlearn_training(
1043
  featurization_complete_marker.touch()
1044
  logger.info(f"πŸ’Ύ Created featurization checkpoint: {featurization_complete_marker}")
1045
 
1046
- # Generate token frequencies for post-training re-regularization
1047
- logger.info("πŸ“Š Computing token frequencies for SIF weighting...")
1048
- compute_token_frequencies_for_sif(teacher_model, features_dir)
1049
-
1050
  except Exception as e:
1051
  logger.exception("πŸ’₯ Tokenlearn featurization failed")
1052
  logger.exception("πŸ’₯ Tokenlearn featurization is required for training - cannot proceed")
@@ -1191,19 +1011,9 @@ def tokenlearn_training(
1191
  logger.info("πŸ”„ Loading model from tokenlearn training...")
1192
  trained_model = StaticModel.from_pretrained(str(trained_model_path))
1193
 
1194
- # Apply post-training re-regularization (POTION Step 4) unless skipped
1195
- if skip_post_training_regularization:
1196
- logger.info("⏭️ Skipping post-training re-regularization (PCA + SIF weighting) as requested")
1197
- final_model = trained_model
1198
- logger.info("βœ… Tokenlearn training pipeline completed successfully (without re-regularization)")
1199
- else:
1200
- logger.info("πŸ”§ Applying post-training re-regularization (PCA + SIF weighting)...")
1201
- final_model = apply_post_training_regularization(
1202
- trained_model, features_dir, pca_dims=distillation_config.optimal_pca_dims
1203
- )
1204
- logger.info("βœ… Tokenlearn training pipeline with post-training re-regularization completed successfully")
1205
-
1206
- return final_model
1207
 
1208
  except ValueError as e:
1209
  if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
@@ -1366,7 +1176,6 @@ def distill_single_teacher(
1366
  base_model,
1367
  teacher_st_model,
1368
  checkpoint_mgr,
1369
- skip_post_training_regularization=distillation_config.skip_post_training_regularization,
1370
  )
1371
 
1372
  # Save final model
@@ -1706,9 +1515,6 @@ def main(
1706
  clear_checkpoints: Annotated[
1707
  bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
1708
  ] = False,
1709
- skip_ptr: Annotated[
1710
- bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
1711
- ] = False,
1712
  use_optimized_dataset: Annotated[
1713
  bool,
1714
  typer.Option(
@@ -1723,17 +1529,15 @@ def main(
1723
  """Unified distillation command with optional training."""
1724
  logger.info("πŸš€ Starting unified Model2Vec distillation workflow")
1725
 
1726
- # Set post-training regularization flag in config
1727
- distillation_config.skip_post_training_regularization = skip_ptr
1728
- if skip_ptr and train:
1729
- logger.info("⏭️ Post-training re-regularization will be skipped (PCA + SIF weighting disabled)")
1730
-
1731
  # Set dataset configuration
1732
  distillation_config.use_optimized_dataset = use_optimized_dataset
1733
  distillation_config.custom_dataset_path = dataset_path
 
1734
  if use_optimized_dataset and train:
1735
  dataset_source = dataset_path or "code_model2vec/dataset"
1736
  logger.info(f"🎯 Using optimized dataset from: {dataset_source}")
 
 
1737
 
1738
  logger.info(f"πŸŽ“ Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
1739
  logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
@@ -2200,7 +2004,7 @@ def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, s
2200
  if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists():
2201
  logger.info("πŸ“Š Custom dataset not found - creating optimized dataset...")
2202
  create_optimized_dataset(
2203
- max_samples_per_lang=10000, # Reasonable size for tokenlearn
2204
  output_dir=custom_dataset_dir,
2205
  create_multiple_formats=False, # Use simple format for tokenlearn
2206
  )
@@ -2230,14 +2034,13 @@ def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, s
2230
  return str(train_json_path), None, "text"
2231
 
2232
 
2233
- def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str, str]:
2234
- """Prepare original CodeSearchNet dataset for tokenlearn featurization."""
2235
- logger.info("πŸ“Š Using original CodeSearchNet dataset for tokenlearn...")
2236
-
2237
  return (
2238
- str(distillation_config.tokenlearn_dataset), # "sentence-transformers/codesearchnet"
2239
- str(distillation_config.tokenlearn_dataset_name), # "pair"
2240
- str(distillation_config.tokenlearn_text_key), # "combined_text"
2241
  )
2242
 
2243
 
 
28
  from pathlib import Path
29
  from typing import Annotated, Any
30
 
 
31
  import torch
32
  import typer
33
  from beam import function
 
409
 
410
 
411
  def load_optimized_dataset(
412
+ max_samples: int | None = None,
413
  checkpoint_manager: BeamCheckpointManager | None = None,
414
  dataset_path: str | None = None,
415
  ) -> list[str]:
 
423
 
424
  dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR
425
 
426
+ # Use configuration default if not specified
427
+ if max_samples is None:
428
+ max_samples = distillation_config.tokenlearn_max_samples
429
+
430
  logger.info(f"🎯 Loading optimized dataset from {dataset_dir}")
431
  logger.info(f"πŸ“Š Target samples: {max_samples}")
432
 
 
465
 
466
 
467
  def load_codesearchnet_dataset(
468
+ max_samples: int | None = None,
469
  checkpoint_manager: BeamCheckpointManager | None = None,
470
  ) -> list[str]:
471
  """Load and format the CodeSearchNet dataset for token frequency computation."""
472
  from datasets import load_dataset
473
 
474
+ # Use configuration default if not specified
475
+ if max_samples is None:
476
+ max_samples = distillation_config.tokenlearn_max_samples
477
+
478
  logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
479
  logger.info(f"Limiting to {max_samples} samples for training efficiency")
480
  logger.info(f"Languages: {', '.join(languages_config.all)}")
 
739
  return teacher_embeddings
740
 
741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
742
  def tokenlearn_training(
743
  student_model: Any,
744
  teacher_model: SentenceTransformer,
745
  checkpoint_manager: BeamCheckpointManager | None = None, # noqa: ARG001
 
746
  ) -> Any:
747
  """
748
  Perform tokenlearn training following the official POTION approach.
 
751
  1. Model2Vec distillation (already done - student_model)
752
  2. Sentence transformer inference (create features)
753
  3. Tokenlearn training
 
754
  """
755
  from pathlib import Path
756
 
 
867
  featurization_complete_marker.touch()
868
  logger.info(f"πŸ’Ύ Created featurization checkpoint: {featurization_complete_marker}")
869
 
 
 
 
 
870
  except Exception as e:
871
  logger.exception("πŸ’₯ Tokenlearn featurization failed")
872
  logger.exception("πŸ’₯ Tokenlearn featurization is required for training - cannot proceed")
 
1011
  logger.info("πŸ”„ Loading model from tokenlearn training...")
1012
  trained_model = StaticModel.from_pretrained(str(trained_model_path))
1013
 
1014
+ # Return the trained model directly
1015
+ logger.info("βœ… Tokenlearn training pipeline completed successfully")
1016
+ return trained_model
 
 
 
 
 
 
 
 
 
 
1017
 
1018
  except ValueError as e:
1019
  if "Number of tokens" in str(e) and "does not match number of vectors" in str(e):
 
1176
  base_model,
1177
  teacher_st_model,
1178
  checkpoint_mgr,
 
1179
  )
1180
 
1181
  # Save final model
 
1515
  clear_checkpoints: Annotated[
1516
  bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
1517
  ] = False,
 
 
 
1518
  use_optimized_dataset: Annotated[
1519
  bool,
1520
  typer.Option(
 
1529
  """Unified distillation command with optional training."""
1530
  logger.info("πŸš€ Starting unified Model2Vec distillation workflow")
1531
 
 
 
 
 
 
1532
  # Set dataset configuration
1533
  distillation_config.use_optimized_dataset = use_optimized_dataset
1534
  distillation_config.custom_dataset_path = dataset_path
1535
+
1536
  if use_optimized_dataset and train:
1537
  dataset_source = dataset_path or "code_model2vec/dataset"
1538
  logger.info(f"🎯 Using optimized dataset from: {dataset_source}")
1539
+ elif train:
1540
+ logger.info("🎯 Using C4 dataset for training (following POTION approach)")
1541
 
1542
  logger.info(f"πŸŽ“ Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
1543
  logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
 
2004
  if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists():
2005
  logger.info("πŸ“Š Custom dataset not found - creating optimized dataset...")
2006
  create_optimized_dataset(
2007
+ max_samples_per_lang=distillation_config.tokenlearn_max_samples // 6, # Divide by number of languages
2008
  output_dir=custom_dataset_dir,
2009
  create_multiple_formats=False, # Use simple format for tokenlearn
2010
  )
 
2034
  return str(train_json_path), None, "text"
2035
 
2036
 
2037
+ def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str | None, str]:
2038
+ """Prepare original dataset for tokenlearn featurization (uses C4 by default following POTION approach)."""
2039
+ logger.info("πŸ“Š Using C4 dataset for tokenlearn (following POTION approach)...")
 
2040
  return (
2041
+ str(distillation_config.tokenlearn_dataset), # "allenai/c4"
2042
+ str(distillation_config.tokenlearn_dataset_name), # "en"
2043
+ str(distillation_config.tokenlearn_text_key), # "text"
2044
  )
2045
 
2046