Other
English
minecraft
action prediction
Kqte committed on
Commit
df37595
·
verified ·
1 Parent(s): dd36003

Delete model/parser_generate.py

Browse files
Files changed (1) hide show
  1. model/parser_generate.py +0 -127
model/parser_generate.py DELETED
@@ -1,127 +0,0 @@
1
- import torch
2
- import json
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
- from datasets import load_dataset
5
- from tqdm import tqdm
6
-
7
# Device placement strategy handed to the model loader below.
device_map = "auto"

# Load the causal LM from the adapter directory in half precision.
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/llamipa/adapter",
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)

tokenizer = AutoTokenizer.from_pretrained(
    "/path/to/meta-llama3-8b/",
    add_eos_token=True,
)

# The pad id is set to the id immediately after EOS.
# NOTE(review): presumably this picks an otherwise unused special token so
# that padding stays distinct from EOS; confirm against the vocabulary.
tokenizer.pad_token_id = tokenizer.eos_token_id + 1
tokenizer.padding_side = "right"

# Text-generation pipeline; each call generates at most 100 new tokens.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=100,
)

# Only the "test" split of the JSON-lines file is used.
test_dataset = load_dataset(
    "json",
    data_files={'test': '/path/to/parser_test_moves_15.jsonl'},
)["test"]
23
-
24
def is_first_moves(sample: str) -> int:
    """Return 1 if *sample* is the opening of a dialogue with no structure yet.

    A sample qualifies when its first line starts with the mission-start
    context marker and its first ``Structure:`` line lists no relations.
    A mission-start sample that has no ``Structure:`` line at all is treated
    the same way (there is nothing to carry over) instead of raising
    ``IndexError``.

    Args:
        sample: Newline-separated sample text.

    Returns:
        1 for a first-moves sample, 0 otherwise (ints are kept for
        compatibility with callers that test truthiness).
    """
    lines = sample.split('\n')
    if not lines[0].startswith('Context: 0 <Buil> Mission has started.'):
        return 0

    structure_line = next(
        (line for line in lines if line.startswith('Structure:')), None
    )
    if structure_line is None:
        return 1

    # Everything between the first and second colon holds the relations.
    rels = structure_line.split(':')[1].strip()
    return 0 if rels else 1
33
-
34
-
35
def check_endpoints(struct, head):
    """Keep only the relations in *struct* whose source index is >= *head*.

    Args:
        struct: Space-separated relations such as ``'ELAB(3,4) QAP(5,6)'``;
            may be ``None`` or empty.
        head: Smallest source index that is still kept.

    Returns:
        The surviving relations joined by single spaces, or ``None`` when
        *struct* is empty or no relation survives.

    Raises:
        IndexError, ValueError: If a non-empty token is not of the form
            ``LABEL(<int>,...)``.
    """
    if not struct:
        return None

    kept = [
        rel
        for rel in struct.split(' ')
        # Consecutive or trailing spaces produce empty tokens; skip those.
        if rel and int(rel.split('(')[1].split(',')[0].strip()) >= head
    ]
    return ' '.join(kept) if kept else None
52
-
53
def add_previous(sample, previous, predictions):
    """Rewrite the ``Structure:`` line(s) of *sample* using earlier output.

    The head index is read from the first line (the number between
    ``Context:`` and the first ``<``).  Each line starting with
    ``Structure:`` is replaced by the relations of *previous* that
    ``check_endpoints`` keeps for that head, followed by *predictions*.

    Returns:
        A pair ``(carried, text)``: the relation string written after
        ``Structure: `` (``None`` if the sample has no such line) and the
        rebuilt sample text.
    """
    lines = sample.split('\n')
    head = int(lines[0].split('Context:')[1].split('<')[0].strip())

    carried = None
    rebuilt = []
    for line in lines:
        if line.startswith('Structure:'):
            surviving = check_endpoints(previous, head)
            if surviving:
                carried = surviving + ' ' + predictions
            else:
                carried = predictions
            line = 'Structure: ' + carried
        rebuilt.append(line)

    return carried, '\n'.join(rebuilt)
72
-
73
# Relation labels accepted in generated output (a set gives O(1) membership tests).
_DS_LABELS = frozenset({
    'COM', 'CONTR', 'CORR', 'QAP', 'ACK', 'ELAB', 'CLARIFQ', 'COND', 'CONTIN',
    'RES', 'EXPL', 'QELAB', 'ALT', 'NARR', 'CONFQ', 'SEQ',
})


def format_gen(preds):
    """Extract the well-formed ``LABEL(src,tgt)`` relations from *preds*.

    *preds* is split on single spaces.  A token is kept only when it has a
    parenthesised argument list whose first two comma-separated fields are
    integers and its label is one of the known relation labels.  Kept
    relations are re-emitted in canonical ``LABEL(src,tgt)`` form, so any
    trailing text after the closing parenthesis is dropped.

    Malformed tokens are reported on stdout and skipped; empty tokens (for
    example from a leading space) are skipped without a message.

    Args:
        preds: Raw generated text following the ``### DS:`` marker.

    Returns:
        The cleaned relations joined by single spaces ('' if none).
    """
    clean_list = []
    for token in preds.split(' '):
        token = token.strip()
        if not token:
            continue

        try:
            args = token.split('(')[1].split(')')[0].split(',')
        except IndexError:
            # No '(' in the token at all.
            print('split error one')
            continue
        label = token.split('(')[0].strip()

        try:
            endpoints = (int(args[0]), int(args[1]))
        except IndexError:
            # Fewer than two arguments.
            print('split error two')
            continue
        except ValueError:
            # An argument is not an integer.
            print('value error three')
            continue

        # Only labels from the known inventory are considered well-formed.
        if label in _DS_LABELS:
            clean_list.append(f'{label}({endpoints[0]},{endpoints[1]})')

    return ' '.join(clean_list)
100
-
101
-
102
- def formatting_prompts_func(example):
103
- output_text = '<|begin_of_text|>Identify the discourse structure (DS) for the new turn in the following excerpt :\n' + example + '\n ### DS:'
104
- return output_text
105
-
106
-
107
# State carried from one sample to the next.
new_generations = None       # cleaned relations generated for the previous sample
previous_generations = None  # relation string written into the previous prompt

# The context manager closes the output file even if generation raises midway.
with open("/path/to/val-output-file.txt", "w") as f:
    for datum in tqdm(test_dataset['sample']):
        if is_first_moves(datum):
            # Start of a dialogue: use the sample as-is and drop old state.
            text = formatting_prompts_func(datum)
            previous_generations = None
        else:
            # Replace the sample's structure with the relations generated so
            # far, so the head EDU and the relations match up.
            # NOTE(review): if the very first sample is not a first-moves
            # sample, new_generations is still None here; confirm the dataset
            # always begins with one.
            previous_generations, amended_text = add_previous(
                datum, previous_generations, new_generations
            )
            text = formatting_prompts_func(amended_text)

        generated = pipe(text)[0]['generated_text']
        print(generated, file=f)
        # Keep only what follows the first '### DS:' cue.
        new_generations = format_gen(generated.split('### DS:')[1])
127
-