Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -16,7 +16,7 @@ def generate(html, entity, website_desc, datasource, year, month, title, prompt)
|
|
16 |
entity_text = entity_text + " |" + ent + "|"
|
17 |
entity_text = "entity ||| <ENTITY_CHAIN>" + entity_text + " </ENTITY_CHAIN> "
|
18 |
else:
|
19 |
-
entity_text = ""
|
20 |
website_desc_text = "Website Description: " + website_desc + " | " if website_desc != "" else ""
|
21 |
datasource_text = "Datasource: " + datasource + " | " if datasource != "" else ""
|
22 |
year_text = "Year: " + year + " | " if year != "" else ""
|
@@ -26,11 +26,12 @@ def generate(html, entity, website_desc, datasource, year, month, title, prompt)
|
|
26 |
final_prompt = html_text + year_text + month_text + website_desc_text + title_text + datasource_text + entity_text + prompt
|
27 |
|
28 |
model = AutoModelForCausalLM.from_pretrained("bs-modeling-metadata/checkpoints_all_04_23", subfolder="checkpoint-30000step")
|
29 |
-
tokenizer = AutoTokenizer.from_pretrained("bs-modeling-metadata/checkpoints_all_04_23", subfolder="tokenizer")
|
|
|
30 |
|
31 |
inputs = tokenizer(final_prompt, return_tensors="pt")
|
32 |
|
33 |
-
outputs = model.generate(**inputs, max_new_tokens=128)
|
34 |
return tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
35 |
|
36 |
|
|
|
16 |
entity_text = entity_text + " |" + ent + "|"
|
17 |
entity_text = "entity ||| <ENTITY_CHAIN>" + entity_text + " </ENTITY_CHAIN> "
|
18 |
else:
|
19 |
+
entity_text = "||| "
|
20 |
website_desc_text = "Website Description: " + website_desc + " | " if website_desc != "" else ""
|
21 |
datasource_text = "Datasource: " + datasource + " | " if datasource != "" else ""
|
22 |
year_text = "Year: " + year + " | " if year != "" else ""
|
|
|
26 |
final_prompt = html_text + year_text + month_text + website_desc_text + title_text + datasource_text + entity_text + prompt
|
27 |
|
28 |
model = AutoModelForCausalLM.from_pretrained("bs-modeling-metadata/checkpoints_all_04_23", subfolder="checkpoint-30000step")
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained("bs-modeling-metadata/checkpoints_all_04_23", subfolder="tokenizer", add_prefix_space=True)
|
30 |
+
bad_words_ids = tokenizer(["<ENTITY_CHAIN>", " </ENTITY_CHAIN> "]).input_ids
|
31 |
|
32 |
inputs = tokenizer(final_prompt, return_tensors="pt")
|
33 |
|
34 |
+
outputs = model.generate(**inputs, max_new_tokens=128, bad_words_ids=bad_words_ids)
|
35 |
return tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
36 |
|
37 |
|