Update README.md
README.md (CHANGED)
## Training Code
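The model was trained in two stages: first on labeled sentence pairs from the SuperLim 2 similarity tasks, then on SweNLI triplets. Lines unchanged by this commit are elided by the diff and marked with `# ...` below.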
```python
from datasets import load_dataset, concatenate_datasets
from sentence_transformers import (
    SentenceTransformer,
    InputExample,
    losses,
    models,
    util,
    datasets,
)
from torch.utils.data import DataLoader
from torch import nn
import random

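# Encoder: a pretrained Swedish BERT, pooling over token embeddings, and a
# dense layer projecting the sentence embedding down to 256 dimensions.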
word_embedding_model = models.Transformer(
    "KBLab/bert-base-swedish-cased-new", max_seq_length=256
)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
dense_model = models.Dense(
    in_features=pooling_model.get_sentence_embedding_dimension(),
    out_features=256,
    activation_function=nn.Tanh(),
)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])

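# Stage 1: fine-tune on labeled sentence pairs with a cosine similarity loss.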
def pair():
    def norm(x):
        x["label"] = x["label"] / m
        # ... (lines not shown in this diff)

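    # Pool every split of each SuperLim 2 pair task into one training set.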
    dd = []
    for sub in ["swepar", "swesim_relatedness", "swesim_similarity"]:
        ds = concatenate_datasets(
            [d for d in load_dataset("sbx/superlim-2", sub).values()]
        )
        if "sentence_1" in ds.features:
            ds = ds.rename_column("sentence_1", "d1")
            ds = ds.rename_column("sentence_2", "d2")
        # ... (lines not shown in this diff)
        train_examples.append(InputExample(texts=[d["d1"], d["d2"]], label=d["label"]))
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=64)
    train_loss = losses.CosineSimilarityLoss(model)
    model.fit(
        train_objectives=[(train_dataloader, train_loss)], epochs=10, warmup_steps=100
    )


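# Stage 2: train on SweNLI with a multiple negatives ranking loss.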
def nli():
    ds = concatenate_datasets(
        [d for d in load_dataset("sbx/superlim-2", "swenli").values()]
    )

    def add_to_samples(sent1, sent2, label):
        if sent1 not in train_data:
            # ... (lines not shown in this diff)

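    # Build training triplets: each anchor sentence is paired with a positive
    # and a negative example (the ranking loss also uses in-batch negatives).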
    for sent1, others in train_data.items():
        if len(others[0]) > 0 and len(others[1]) > 0:
            train_samples.append(
                InputExample(
                    texts=[
                        sent1,
                        random.choice(list(others[0])),
                        random.choice(list(others[1])),
                    ]
                )
            )
            train_samples.append(
                InputExample(
                    texts=[
                        random.choice(list(others[0])),
                        sent1,
                        random.choice(list(others[1])),
                    ]
                )
            )
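    # NoDuplicatesDataLoader keeps identical sentences out of the same batch,
    # which matters because in-batch examples serve as negatives.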
    train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=64)
    train_loss = losses.MultipleNegativesRankingLoss(model)
    model.fit(
        train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100
    )


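# Run both training stages in sequence, then save the model.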
pair()
nli()
model.save()
```
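Once trained, the model can be reloaded and used like any other Sentence Transformer. The sketch below is illustrative rather than part of the training script: the save path and the example sentences are placeholders, and `util.cos_sim` is just one way to compare the resulting embeddings.

```python
from sentence_transformers import SentenceTransformer, util

# Placeholder path: point this at the directory the model was saved to.
model = SentenceTransformer("path/to/saved-model")

# Two Swedish paraphrases and one unrelated sentence.
sentences = [
    "Mannen åt sin smörgås.",  # "The man ate his sandwich."
    "Mannen åt sin macka.",    # "The man ate his sandwich." (colloquial)
    "Katten jagade musen.",    # "The cat chased the mouse."
]
embeddings = model.encode(sentences)

# Pairwise cosine similarities; the paraphrase pair should score highest.
print(util.cos_sim(embeddings, embeddings))
```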