samwell commited on
Commit
8765359
·
verified ·
1 Parent(s): ee08504

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -1,10 +1,14 @@
1
  import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
 
 
 
4
 
5
  model_name = 'synthome-fyi-paper-classification'
6
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
8
 
9
  def classify_text(text):
10
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -12,7 +16,8 @@ def classify_text(text):
12
  outputs = model(**inputs)
13
  logits = outputs.logits
14
  prediction = logits.argmax().item()
15
- return prediction
 
16
 
17
  iface = gr.Interface(fn=classify_text, inputs="text", outputs="text")
18
  iface.launch()
 
1
  import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
4
+ import os
5
+
6
+ token = os.getenv('token')
7
 
8
  model_name = 'synthome-fyi-paper-classification'
9
+
10
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, use_auth_token=token)
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
12
 
13
  def classify_text(text):
14
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
16
  outputs = model(**inputs)
17
  logits = outputs.logits
18
  prediction = logits.argmax().item()
19
+ labels = ['AI only', 'Bio only', 'AIxBio']
20
+ return labels[prediction]
21
 
22
  iface = gr.Interface(fn=classify_text, inputs="text", outputs="text")
23
  iface.launch()