---
license: mit
language:
- ko
pipeline_tag: feature-extraction
---

# KoSimCSE Training on Amazon SageMaker

## Usage

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from transformers import AutoModel, AutoTokenizer, logging
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions


class SimCSEConfig(PretrainedConfig):
    def __init__(self, version=1.0, **kwargs):
        self.version = version
        super().__init__(**kwargs)


class SimCSEModel(PreTrainedModel):
    config_class = SimCSEConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = AutoModel.from_pretrained(config.base_model)
        self.hidden_size: int = self.backbone.config.hidden_size
        # MLP head used only during training (see forward below)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.activation = nn.Tanh()

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor = None,
        # RoBERTa variants don't have token_type_ids, so this argument is optional
        token_type_ids: Tensor = None,
    ) -> Tensor:
        # shape of input_ids: (batch_size, seq_len)
        # shape of attention_mask: (batch_size, seq_len)
        outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Use the [CLS] token representation as the sentence embedding
        emb = outputs.last_hidden_state[:, 0]

        # Following the SimCSE paper, the MLP head is applied only during training
        if self.training:
            emb = self.dense(emb)
            emb = self.activation(emb)

        return emb


def show_embedding_score(tokenizer, model, sentences):
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    embeddings = model(**inputs)
    score01 = cal_score(embeddings[0, :], embeddings[1, :])  # similar pair
    score02 = cal_score(embeddings[0, :], embeddings[2, :])  # dissimilar pair
    print(score01, score02)


def cal_score(a, b):
    # Cosine similarity scaled to a 0-100 range
    if len(a.shape) == 1: a = a.unsqueeze(0)
    if len(b.shape) == 1: b = b.unsqueeze(0)
    a_norm = a / a.norm(dim=1)[:, None]
    b_norm = b / b.norm(dim=1)[:, None]
    return torch.mm(a_norm, b_norm.transpose(0, 1)) * 100


# Load the pre-trained model
model = SimCSEModel.from_pretrained("daekeun-ml/KoSimCSE-unsupervised-roberta-large")
tokenizer = AutoTokenizer.from_pretrained("daekeun-ml/KoSimCSE-unsupervised-roberta-large")

# Inference example
sentences = ['이번 주 일요일에 분당 이마트 점은 문을 여나요?',  # Is the Bundang E-Mart branch open this Sunday?
             '일요일에 분당 이마트는 문 열어요?',              # Does the Bundang E-Mart open on Sundays?
             '분당 이마트 점은 토요일에 몇 시까지 하나요']       # Until what time is the Bundang E-Mart branch open on Saturdays?

show_embedding_score(tokenizer, model.cpu(), sentences)
```
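
Since the first two sentences are paraphrases (both ask whether the Bundang E-Mart is open on Sunday) while the third asks about Saturday closing time, `score01` should come out noticeably higher than `score02`.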

## Introduction

[SimCSE](https://aclanthology.org/2021.emnlp-main.552/) is a highly efficient and innovative embedding technique based on contrastive learning. Unsupervised training requires no ground-truth labels, and strong supervised results can be obtained when a good NLI (Natural Language Inference) dataset is available. The concept is simple and the pseudo-code is intuitive, so the implementation itself is not difficult, yet many people still struggle to train this model well.

The official implementation from the paper's authors is publicly available, but it is not well suited to a step-by-step walkthrough. We therefore reorganized the code based on [Simple-SimCSE's GitHub](https://github.com/hppRC/simple-simcse) so that even ML beginners can train the model from scratch, one step at a time. It is minimalist code aimed at beginners, but data scientists and ML engineers can also make good use of it.
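
To make the "simple concept" concrete, the sketch below illustrates the core of the unsupervised SimCSE objective: the same batch is encoded twice so that two different dropout masks produce a positive pair, and an in-batch cross-entropy loss over cosine similarities pulls each pair together. This is an illustrative sketch rather than the exact training code of this repository; `encoder` stands for any sentence encoder such as the `SimCSEModel` shown above, and the temperature follows the hyperparameters listed in the Performance section.

```python
import torch
import torch.nn.functional as F

def unsupervised_simcse_loss(encoder, input_ids, attention_mask, temperature=0.05):
    """Dropout-based contrastive loss (illustrative sketch).

    The same batch is encoded twice; because dropout is active in training
    mode, the two passes yield slightly different embeddings, which serve as
    positive pairs. All other sentences in the batch act as negatives.
    """
    z1 = encoder(input_ids=input_ids, attention_mask=attention_mask)  # (B, H)
    z2 = encoder(input_ids=input_ids, attention_mask=attention_mask)  # (B, H)

    # Pairwise cosine similarity between the two views, scaled by temperature
    sim = F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=-1) / temperature  # (B, B)

    # The i-th sentence in view 1 should match the i-th sentence in view 2
    labels = torch.arange(sim.size(0), device=sim.device)
    return F.cross_entropy(sim, labels)
```
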
### Added over Simple-SimCSE
- Added a supervised learning part that shows, step by step, how to construct the training dataset.
- Added distributed training logic. If you have a multi-GPU setup, you can train faster.
- Added SageMaker training. `ml.g4dn.xlarge` trains fine, but we recommend `ml.g4dn.12xlarge` or `ml.g5.12xlarge` for faster training.

## Requirements
We recommend preparing an Amazon SageMaker instance with the specifications below to run this hands-on.

### SageMaker Notebook instance
- `ml.g4dn.xlarge`

### SageMaker Training instance
- `ml.g4dn.xlarge` (minimum)
- `ml.g5.12xlarge` (recommended)

## Datasets

For supervised learning, you need an NLI dataset that labels the relationship between two sentences. For unsupervised learning, we recommend using Wikipedia raw text split into sentences. This hands-on uses datasets hosted on Hugging Face, but you can also prepare your own dataset.

The datasets used in this hands-on are as follows (a sketch of how the NLI data can be turned into training triplets appears after the list):

#### Supervised
- [Klue-NLI](https://huggingface.co/datasets/klue/viewer/nli/)
- [Kor-NLI](https://huggingface.co/datasets/kor_nli)

#### Unsupervised
- [kowiki-sentences](https://huggingface.co/datasets/heegyu/kowiki-sentences): the 2022-10-01 Korean Wikipedia dump split into sentences with the kss sentence splitter (backend=mecab).
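
For the supervised setting, each NLI example is typically converted into a triplet: the premise as the anchor, an entailed hypothesis as the positive, and a contradicted hypothesis as the hard negative. The sketch below shows one way to build such triplets from KLUE-NLI with the `datasets` library; the column names and label ids (0 = entailment, 2 = contradiction) follow the public KLUE-NLI schema, so verify them against the dataset you actually use.

```python
from collections import defaultdict
from datasets import load_dataset

# Illustrative sketch: build (anchor, positive, negative) triplets from KLUE-NLI.
# Assumes the standard KLUE-NLI columns: premise, hypothesis, label
# (0 = entailment, 1 = neutral, 2 = contradiction).
nli = load_dataset("klue", "nli", split="train")

pairs = defaultdict(dict)
for ex in nli:
    if ex["label"] == 0:      # entailment -> positive
        pairs[ex["premise"]]["positive"] = ex["hypothesis"]
    elif ex["label"] == 2:    # contradiction -> hard negative
        pairs[ex["premise"]]["negative"] = ex["hypothesis"]

# Keep only premises that have both a positive and a hard negative
triplets = [
    (premise, p["positive"], p["negative"])
    for premise, p in pairs.items()
    if "positive" in p and "negative" in p
]
print(f"{len(triplets)} triplets built")
```
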
## How to train
- See https://github.com/daekeun-ml/KoSimCSE-SageMaker
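
If you want to launch training as a SageMaker training job, a minimal estimator call looks roughly like the sketch below. The entry-point name, source directory, framework versions, and hyperparameter names here are assumptions for illustration only; the actual training script and its arguments live in the repository linked above.

```python
import sagemaker
from sagemaker.huggingface import HuggingFace

role = sagemaker.get_execution_role()

# Hypothetical entry point and hyperparameter names, for illustration only;
# see the KoSimCSE-SageMaker repository for the real training script.
estimator = HuggingFace(
    entry_point="train.py",
    source_dir="src",
    role=role,
    instance_type="ml.g5.12xlarge",
    instance_count=1,
    transformers_version="4.26",
    pytorch_version="1.13",
    py_version="py39",
    hyperparameters={
        "batch_size": 64,
        "num_epochs": 1,   # 3 for supervised training
        "lr": 3e-5,
        "temperature": 0.05,
        "max_seq_len": 32,
    },
)
estimator.fit()
```
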
## Performance
We trained with hyperparameters similar to those in the paper and did not perform any tuning: 1 epoch for unsupervised training and 3 epochs for supervised training. A longer max sequence length does not guarantee higher performance; building a good NLI dataset matters more.

```json
{
  "batch_size": 64,
  "num_epochs": 1,
  "lr": 3e-05,
  "num_warmup_steps": 0,
  "temperature": 0.05,
  "lr_scheduler_type": "linear",
  "max_seq_len": 32,
  "use_fp16": "True"
}
```

### KLUE-STS
| Model | Avg | Cosine Pearson | Cosine Spearman | Euclidean Pearson | Euclidean Spearman | Manhattan Pearson | Manhattan Spearman | Dot Pearson | Dot Spearman |
|------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| KoSimCSE-RoBERTa-base (Unsupervised) | 81.17 | 81.27 | 80.96 | 81.70 | 80.97 | 81.63 | 80.89 | 81.12 | 80.81 |
| KoSimCSE-RoBERTa-base (Supervised) | 84.19 | 83.04 | 84.46 | 84.97 | 84.50 | 84.95 | 84.45 | 82.88 | 84.28 |
| KoSimCSE-RoBERTa-large (Unsupervised) | 81.96 | 82.09 | 81.71 | 82.45 | 81.73 | 82.42 | 81.69 | 81.98 | 81.58 |
| KoSimCSE-RoBERTa-large (Supervised) | 85.37 | 84.38 | 85.99 | 85.97 | 85.81 | 86.00 | 85.79 | 83.87 | 85.15 |

### Kor-STS
| Model | Avg | Cosine Pearson | Cosine Spearman | Euclidean Pearson | Euclidean Spearman | Manhattan Pearson | Manhattan Spearman | Dot Pearson | Dot Spearman |
|------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| KoSimCSE-RoBERTa-base (Unsupervised) | 81.20 | 81.53 | 81.17 | 80.89 | 81.20 | 80.93 | 81.22 | 81.48 | 81.14 |
| KoSimCSE-RoBERTa-base (Supervised) | 85.33 | 85.16 | 85.46 | 85.37 | 85.45 | 85.31 | 85.37 | 85.13 | 85.41 |
| KoSimCSE-RoBERTa-large (Unsupervised) | 81.71 | 82.10 | 81.78 | 81.12 | 81.78 | 81.15 | 81.80 | 82.15 | 81.80 |
| KoSimCSE-RoBERTa-large (Supervised) | 85.54 | 85.41 | 85.78 | 85.18 | 85.51 | 85.26 | 85.61 | 85.70 | 85.90 |
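
Each cell is a Pearson or Spearman correlation (×100) between the human similarity labels and the similarity computed from the sentence embeddings (cosine, Euclidean, Manhattan, or dot product). As a reference, the cosine-based columns can be computed roughly as in the sketch below, where `encode` is assumed to be any callable that maps a list of sentences to a `(N, H)` tensor, e.g. a wrapper around the tokenizer and `SimCSEModel` from the Usage section.

```python
import torch.nn.functional as F
from scipy.stats import pearsonr, spearmanr

def sts_cosine_scores(encode, sentences1, sentences2, gold_scores):
    """Illustrative sketch of the 'Cosine Pearson/Spearman' columns."""
    emb1 = encode(sentences1)  # (N, H) tensor
    emb2 = encode(sentences2)  # (N, H) tensor
    cos = F.cosine_similarity(emb1, emb2, dim=-1).tolist()

    # Correlate predicted similarities with the gold STS scores, scaled to x100
    pearson = pearsonr(cos, gold_scores)[0] * 100
    spearman = spearmanr(cos, gold_scores)[0] * 100
    return pearson, spearman
```
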
## References
- Simple-SimCSE: https://github.com/hppRC/simple-simcse
- KoSimCSE: https://github.com/BM-K/KoSimCSE-SKT
- SimCSE (official): https://github.com/princeton-nlp/SimCSE
- SimCSE paper: https://aclanthology.org/2021.emnlp-main.552