psaegert
committed on
Add model files
Browse files
- dataset_train.yaml +2 -0
- dataset_val.yaml +2 -0
- expression_space.yaml +364 -0
- nsr.yaml +20 -0
- skeleton_pool_train.yaml +31 -0
- skeleton_pool_val.yaml +30 -0
- state_dict.pt +3 -0
- train.yaml +22 -0
dataset_train.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
skeleton_pool: /home/psaegert/Projects/flash-ansr/configs/v7.20/././skeleton_pool_train.yaml
|
2 |
+
padding: zero
|
dataset_val.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
skeleton_pool: /home/psaegert/Projects/flash-ansr/configs/v7.20/././skeleton_pool_val.yaml
|
2 |
+
padding: zero
|
expression_space.yaml
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
operators:
|
2 |
+
+:
|
3 |
+
realization: +
|
4 |
+
alias:
|
5 |
+
- add
|
6 |
+
- plus
|
7 |
+
inverse: '-'
|
8 |
+
arity: 2
|
9 |
+
weight: 10
|
10 |
+
precedence: 1
|
11 |
+
commutative: true
|
12 |
+
symmetry: 0
|
13 |
+
positive: false
|
14 |
+
monotonicity: 0
|
15 |
+
'-':
|
16 |
+
realization: '-'
|
17 |
+
alias:
|
18 |
+
- sub
|
19 |
+
- minus
|
20 |
+
inverse: +
|
21 |
+
arity: 2
|
22 |
+
weight: 5
|
23 |
+
precedence: 1
|
24 |
+
commutative: false
|
25 |
+
symmetry: 0
|
26 |
+
positive: false
|
27 |
+
monotonicity: 0
|
28 |
+
neg:
|
29 |
+
realization: nsrops.neg
|
30 |
+
alias:
|
31 |
+
- negative
|
32 |
+
inverse: neg
|
33 |
+
arity: 1
|
34 |
+
weight: 5
|
35 |
+
precedence: 2.5
|
36 |
+
commutative: false
|
37 |
+
symmetry: -1
|
38 |
+
positive: false
|
39 |
+
monotonicity: -1
|
40 |
+
'*':
|
41 |
+
realization: '*'
|
42 |
+
alias:
|
43 |
+
- mul
|
44 |
+
- times
|
45 |
+
inverse: /
|
46 |
+
arity: 2
|
47 |
+
weight: 10
|
48 |
+
precedence: 2
|
49 |
+
commutative: true
|
50 |
+
symmetry: 0
|
51 |
+
positive: false
|
52 |
+
monotonicity: 0
|
53 |
+
/:
|
54 |
+
realization: nsrops.div
|
55 |
+
alias:
|
56 |
+
- div
|
57 |
+
- divide
|
58 |
+
inverse: '*'
|
59 |
+
arity: 2
|
60 |
+
weight: 5
|
61 |
+
precedence: 2
|
62 |
+
commutative: false
|
63 |
+
symmetry: 0
|
64 |
+
positive: false
|
65 |
+
monotonicity: 0
|
66 |
+
abs:
|
67 |
+
realization: abs
|
68 |
+
alias:
|
69 |
+
- absolute
|
70 |
+
inverse: null
|
71 |
+
arity: 1
|
72 |
+
weight: 4
|
73 |
+
precedence: 3
|
74 |
+
commutative: false
|
75 |
+
symmetry: 1
|
76 |
+
positive: true
|
77 |
+
monotonicity: 0
|
78 |
+
inv:
|
79 |
+
realization: nsrops.inv
|
80 |
+
alias:
|
81 |
+
- inverse
|
82 |
+
inverse: inv
|
83 |
+
arity: 1
|
84 |
+
weight: 4
|
85 |
+
precedence: 4
|
86 |
+
commutative: false
|
87 |
+
symmetry: -1
|
88 |
+
positive: false
|
89 |
+
monotonicity: -1
|
90 |
+
pow2:
|
91 |
+
realization: nsrops.pow2
|
92 |
+
alias:
|
93 |
+
- square
|
94 |
+
inverse: null
|
95 |
+
arity: 1
|
96 |
+
weight: 4
|
97 |
+
precedence: 3
|
98 |
+
commutative: false
|
99 |
+
symmetry: 1
|
100 |
+
positive: true
|
101 |
+
monotonicity: 0
|
102 |
+
pow3:
|
103 |
+
realization: nsrops.pow3
|
104 |
+
alias:
|
105 |
+
- cube
|
106 |
+
inverse: pow1_3
|
107 |
+
arity: 1
|
108 |
+
weight: 2
|
109 |
+
precedence: 3
|
110 |
+
commutative: false
|
111 |
+
symmetry: -1
|
112 |
+
positive: false
|
113 |
+
monotonicity: 1
|
114 |
+
pow4:
|
115 |
+
realization: nsrops.pow4
|
116 |
+
alias: []
|
117 |
+
inverse: null
|
118 |
+
arity: 1
|
119 |
+
weight: 1
|
120 |
+
precedence: 3
|
121 |
+
commutative: false
|
122 |
+
symmetry: 1
|
123 |
+
positive: true
|
124 |
+
monotonicity: 0
|
125 |
+
pow5:
|
126 |
+
realization: nsrops.pow5
|
127 |
+
alias: []
|
128 |
+
inverse: pow1_5
|
129 |
+
arity: 1
|
130 |
+
weight: 1
|
131 |
+
precedence: 3
|
132 |
+
commutative: false
|
133 |
+
symmetry: -1
|
134 |
+
positive: false
|
135 |
+
monotonicity: 1
|
136 |
+
pow1_2:
|
137 |
+
realization: nsrops.pow1_2
|
138 |
+
alias:
|
139 |
+
- sqrt
|
140 |
+
inverse: null
|
141 |
+
arity: 1
|
142 |
+
weight: 4
|
143 |
+
precedence: 3
|
144 |
+
commutative: false
|
145 |
+
symmetry: 0
|
146 |
+
positive: true
|
147 |
+
monotonicity: 1
|
148 |
+
pow1_3:
|
149 |
+
realization: nsrops.pow1_3
|
150 |
+
alias: []
|
151 |
+
inverse: null
|
152 |
+
arity: 1
|
153 |
+
weight: 2
|
154 |
+
precedence: 3
|
155 |
+
commutative: false
|
156 |
+
symmetry: -1
|
157 |
+
positive: false
|
158 |
+
monotonicity: 1
|
159 |
+
pow1_4:
|
160 |
+
realization: nsrops.pow1_4
|
161 |
+
alias: []
|
162 |
+
inverse: null
|
163 |
+
arity: 1
|
164 |
+
weight: 1
|
165 |
+
precedence: 3
|
166 |
+
commutative: false
|
167 |
+
symmetry: 0
|
168 |
+
positive: true
|
169 |
+
monotonicity: 1
|
170 |
+
pow1_5:
|
171 |
+
realization: nsrops.pow1_5
|
172 |
+
alias: []
|
173 |
+
inverse: null
|
174 |
+
arity: 1
|
175 |
+
weight: 1
|
176 |
+
precedence: 3
|
177 |
+
commutative: false
|
178 |
+
symmetry: -1
|
179 |
+
positive: false
|
180 |
+
monotonicity: 1
|
181 |
+
sin:
|
182 |
+
realization: numpy.sin
|
183 |
+
alias: []
|
184 |
+
inverse: asin
|
185 |
+
arity: 1
|
186 |
+
weight: 4
|
187 |
+
precedence: 2
|
188 |
+
commutative: false
|
189 |
+
symmetry: -1
|
190 |
+
positive: false
|
191 |
+
monotonicity: 0
|
192 |
+
cos:
|
193 |
+
realization: numpy.cos
|
194 |
+
alias: []
|
195 |
+
inverse: acos
|
196 |
+
arity: 1
|
197 |
+
weight: 4
|
198 |
+
precedence: 2
|
199 |
+
commutative: false
|
200 |
+
symmetry: 1
|
201 |
+
positive: false
|
202 |
+
monotonicity: 0
|
203 |
+
tan:
|
204 |
+
realization: numpy.tan
|
205 |
+
alias: []
|
206 |
+
inverse: atan
|
207 |
+
arity: 1
|
208 |
+
weight: 4
|
209 |
+
precedence: 2
|
210 |
+
commutative: false
|
211 |
+
symmetry: -1
|
212 |
+
positive: false
|
213 |
+
monotonicity: 0
|
214 |
+
asin:
|
215 |
+
realization: numpy.arcsin
|
216 |
+
alias:
|
217 |
+
- arcsin
|
218 |
+
inverse: sin
|
219 |
+
arity: 1
|
220 |
+
weight: 2
|
221 |
+
precedence: 2
|
222 |
+
commutative: false
|
223 |
+
symmetry: -1
|
224 |
+
positive: false
|
225 |
+
monotonicity: 1
|
226 |
+
acos:
|
227 |
+
realization: numpy.arccos
|
228 |
+
alias:
|
229 |
+
- arccos
|
230 |
+
inverse: cos
|
231 |
+
arity: 1
|
232 |
+
weight: 2
|
233 |
+
precedence: 2
|
234 |
+
commutative: false
|
235 |
+
symmetry: 0
|
236 |
+
positive: true
|
237 |
+
monotonicity: 1
|
238 |
+
atan:
|
239 |
+
realization: numpy.arctan
|
240 |
+
alias:
|
241 |
+
- arctan
|
242 |
+
inverse: tan
|
243 |
+
arity: 1
|
244 |
+
weight: 2
|
245 |
+
precedence: 2
|
246 |
+
commutative: false
|
247 |
+
symmetry: -1
|
248 |
+
positive: false
|
249 |
+
monotonicity: 1
|
250 |
+
exp:
|
251 |
+
realization: numpy.exp
|
252 |
+
alias: []
|
253 |
+
inverse: log
|
254 |
+
arity: 1
|
255 |
+
weight: 4
|
256 |
+
precedence: 3
|
257 |
+
commutative: false
|
258 |
+
symmetry: 0
|
259 |
+
positive: true
|
260 |
+
monotonicity: 1
|
261 |
+
log:
|
262 |
+
realization: numpy.log
|
263 |
+
alias:
|
264 |
+
- ln
|
265 |
+
inverse: exp
|
266 |
+
arity: 1
|
267 |
+
weight: 4
|
268 |
+
precedence: 2
|
269 |
+
commutative: false
|
270 |
+
symmetry: 0
|
271 |
+
positive: false
|
272 |
+
monotonicity: 1
|
273 |
+
mult2:
|
274 |
+
realization: nsrops.mult2
|
275 |
+
alias: []
|
276 |
+
inverse: div2
|
277 |
+
arity: 1
|
278 |
+
weight: 1
|
279 |
+
precedence: 3
|
280 |
+
commutative: false
|
281 |
+
symmetry: 0
|
282 |
+
positive: false
|
283 |
+
monotonicity: 1
|
284 |
+
mult3:
|
285 |
+
realization: nsrops.mult3
|
286 |
+
alias: []
|
287 |
+
inverse: div3
|
288 |
+
arity: 1
|
289 |
+
weight: 1
|
290 |
+
precedence: 3
|
291 |
+
commutative: false
|
292 |
+
symmetry: 0
|
293 |
+
positive: false
|
294 |
+
monotonicity: 1
|
295 |
+
mult4:
|
296 |
+
realization: nsrops.mult4
|
297 |
+
alias: []
|
298 |
+
inverse: div4
|
299 |
+
arity: 1
|
300 |
+
weight: 1
|
301 |
+
precedence: 3
|
302 |
+
commutative: false
|
303 |
+
symmetry: 0
|
304 |
+
positive: false
|
305 |
+
monotonicity: 1
|
306 |
+
mult5:
|
307 |
+
realization: nsrops.mult5
|
308 |
+
alias: []
|
309 |
+
inverse: div5
|
310 |
+
arity: 1
|
311 |
+
weight: 1
|
312 |
+
precedence: 3
|
313 |
+
commutative: false
|
314 |
+
symmetry: 0
|
315 |
+
positive: false
|
316 |
+
monotonicity: 1
|
317 |
+
div2:
|
318 |
+
realization: nsrops.div2
|
319 |
+
alias: []
|
320 |
+
inverse: mult2
|
321 |
+
arity: 1
|
322 |
+
weight: 1
|
323 |
+
precedence: 3
|
324 |
+
commutative: false
|
325 |
+
symmetry: 0
|
326 |
+
positive: false
|
327 |
+
monotonicity: 1
|
328 |
+
div3:
|
329 |
+
realization: nsrops.div3
|
330 |
+
alias: []
|
331 |
+
inverse: mult3
|
332 |
+
arity: 1
|
333 |
+
weight: 1
|
334 |
+
precedence: 3
|
335 |
+
commutative: false
|
336 |
+
symmetry: 0
|
337 |
+
positive: false
|
338 |
+
monotonicity: 1
|
339 |
+
div4:
|
340 |
+
realization: nsrops.div4
|
341 |
+
alias: []
|
342 |
+
inverse: mult4
|
343 |
+
arity: 1
|
344 |
+
weight: 1
|
345 |
+
precedence: 3
|
346 |
+
commutative: false
|
347 |
+
symmetry: 0
|
348 |
+
positive: false
|
349 |
+
monotonicity: 1
|
350 |
+
div5:
|
351 |
+
realization: nsrops.div5
|
352 |
+
alias: []
|
353 |
+
inverse: mult5
|
354 |
+
arity: 1
|
355 |
+
weight: 1
|
356 |
+
precedence: 3
|
357 |
+
commutative: false
|
358 |
+
symmetry: 0
|
359 |
+
positive: false
|
360 |
+
monotonicity: 1
|
361 |
+
variables: 3
|
362 |
+
simplification: auto_flash
|
363 |
+
simplification_kwargs:
|
364 |
+
rules_file: '{{ROOT}}/data/ansr-data/simplification_rules/v7.20.json'
|
nsr.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
encoder_max_n_variables: 4
|
2 |
+
pre_encoder_input_type: ieee-754
|
3 |
+
pre_encoder_support_nan: false
|
4 |
+
encoder: SetTransformer
|
5 |
+
encoder_kwargs:
|
6 |
+
hidden_size: 512
|
7 |
+
n_enc_isab: 5
|
8 |
+
n_dec_sab: 2
|
9 |
+
n_induce: 64
|
10 |
+
n_heads: 8
|
11 |
+
layer_norm: false
|
12 |
+
n_seeds: 64
|
13 |
+
size: 512
|
14 |
+
decoder_n_heads: 8
|
15 |
+
decoder_ff_size: 512
|
16 |
+
decoder_dropout: 0.1
|
17 |
+
decoder_n_layers: 5
|
18 |
+
learnable_positional_embeddings: false
|
19 |
+
max_input_length: null
|
20 |
+
expression_space: /home/psaegert/Projects/flash-ansr/configs/v7.20/././expression_space.yaml
|
skeleton_pool_train.yaml
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
expression_space: /home/psaegert/Projects/flash-ansr/configs/v7.20/./././expression_space.yaml
|
2 |
+
holdout_pools:
|
3 |
+
- '{{ROOT}}/data/ansr-data/v7.20/skeleton_pool_val/'
|
4 |
+
- '{{ROOT}}/data/ansr-data/test_set/soose_nc/skeleton_pool/'
|
5 |
+
- '{{ROOT}}/data/ansr-data/test_set/feynman/skeleton_pool/'
|
6 |
+
- '{{ROOT}}/data/ansr-data/test_set/nguyen/skeleton_pool/'
|
7 |
+
- '{{ROOT}}/data/ansr-data/test_set/pool_15/skeleton_pool/'
|
8 |
+
sample_strategy:
|
9 |
+
n_operator_distribution: length_proportional
|
10 |
+
min_operators: 0
|
11 |
+
max_operators: 10
|
12 |
+
power: 1
|
13 |
+
max_length: 21
|
14 |
+
max_tries: 1
|
15 |
+
independent_dimensions: true
|
16 |
+
allow_nan: false
|
17 |
+
simplify: true
|
18 |
+
literal_prior: uniform
|
19 |
+
literal_prior_kwargs:
|
20 |
+
low: -5
|
21 |
+
high: 5
|
22 |
+
support_prior: uniform_intervals
|
23 |
+
support_prior_kwargs:
|
24 |
+
low: -10
|
25 |
+
high: 10
|
26 |
+
n_support_prior: uniform
|
27 |
+
n_support_prior_kwargs:
|
28 |
+
low: 16
|
29 |
+
high: 512
|
30 |
+
min_value: 16
|
31 |
+
max_value: 512
|
skeleton_pool_val.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
expression_space: /home/psaegert/Projects/flash-ansr/configs/v7.20/./././expression_space.yaml
|
2 |
+
holdout_pools:
|
3 |
+
- '{{ROOT}}/data/ansr-data/test_set/soose_nc/skeleton_pool/'
|
4 |
+
- '{{ROOT}}/data/ansr-data/test_set/feynman/skeleton_pool/'
|
5 |
+
- '{{ROOT}}/data/ansr-data/test_set/nguyen/skeleton_pool/'
|
6 |
+
- '{{ROOT}}/data/ansr-data/test_set/pool_15/skeleton_pool/'
|
7 |
+
sample_strategy:
|
8 |
+
n_operator_distribution: length_proportional
|
9 |
+
min_operators: 0
|
10 |
+
max_operators: 10
|
11 |
+
power: 1
|
12 |
+
max_length: 21
|
13 |
+
max_tries: 1
|
14 |
+
independent_dimensions: true
|
15 |
+
allow_nan: false
|
16 |
+
simplify: true
|
17 |
+
literal_prior: uniform
|
18 |
+
literal_prior_kwargs:
|
19 |
+
low: -5
|
20 |
+
high: 5
|
21 |
+
support_prior: uniform_intervals
|
22 |
+
support_prior_kwargs:
|
23 |
+
low: -10
|
24 |
+
high: 10
|
25 |
+
n_support_prior: uniform
|
26 |
+
n_support_prior_kwargs:
|
27 |
+
low: 16
|
28 |
+
high: 512
|
29 |
+
min_value: 16
|
30 |
+
max_value: 512
|
state_dict.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:23ce69742a182e664e3617d3d047322699ada0f1c5544b9c0e78e2e2349564b8
|
3 |
+
size 108657169
|
train.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model: /home/psaegert/Projects/flash-ansr/configs/v7.20/./nsr.yaml
|
2 |
+
optimizer:
|
3 |
+
name: AdamW
|
4 |
+
kwargs:
|
5 |
+
lr: 1
|
6 |
+
weight_decay: 0.01
|
7 |
+
amsgrad: true
|
8 |
+
lr_scheduler:
|
9 |
+
name: WarmupLinearAnnealing
|
10 |
+
kwargs:
|
11 |
+
min_lr: 0
|
12 |
+
max_lr: 1e-4
|
13 |
+
warmup_steps: 10000
|
14 |
+
total_steps: 1500000
|
15 |
+
batch_size: 128
|
16 |
+
train_dataset: /home/psaegert/Projects/flash-ansr/configs/v7.20/./dataset_train.yaml
|
17 |
+
val_dataset: /home/psaegert/Projects/flash-ansr/configs/v7.20/./dataset_val.yaml
|
18 |
+
val_batch_size: 128
|
19 |
+
val_size: 100000
|
20 |
+
numeric_prediction_loss_weight: 0
|
21 |
+
steps: 1500000
|
22 |
+
device: cuda
|