charlie572 commited on
Commit
b8dde21
·
1 Parent(s): 4f18f83

Add tool to get similar articles

Browse files
Files changed (1) hide show
  1. app.py +46 -2
app.py CHANGED
@@ -1,5 +1,12 @@
 
1
  import gradio as gr
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
 
 
 
 
3
 
4
  def summarize(text):
5
  checkpoint = "sshleifer/distilbart-cnn-12-6"
@@ -28,9 +35,46 @@ def generate_question(text):
28
  return question
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  summarize_interface = gr.Interface(fn=summarize, inputs="text", outputs="text")
32
  question_interface = gr.Interface(fn=generate_question, inputs="text", outputs="text")
 
33
  tabs = gr.TabbedInterface(
34
- [summarize_interface, question_interface], ["Summarize an article", "Generate a question"]
 
35
  )
36
  tabs.launch()
 
1
+ import torch
2
  import gradio as gr
3
+ from transformers import (
4
+ AutoModelForSeq2SeqLM,
5
+ AutoTokenizer,
6
+ AutoModelForTokenClassification,
7
+ )
8
+ import googlesearch
9
+
10
 
11
  def summarize(text):
12
  checkpoint = "sshleifer/distilbart-cnn-12-6"
 
35
  return question
36
 
37
 
38
+ def get_similar_articles(text):
39
+ tokenizer = AutoTokenizer.from_pretrained("yanekyuk/bert-keyword-extractor")
40
+ model = AutoModelForTokenClassification.from_pretrained("yanekyuk/bert-keyword-extractor")
41
+
42
+ inputs = tokenizer(text, truncation=True, return_tensors="pt")
43
+ outputs = model(**inputs)
44
+
45
+ keyword_tokens = []
46
+ current_keyword_tokens = []
47
+ for token, logits in zip(inputs.input_ids[0], outputs.logits[0]):
48
+ token_type = torch.argmax(logits).item()
49
+ if token_type > 0:
50
+ current_keyword_tokens.append(token.item())
51
+ elif len(current_keyword_tokens) > 0:
52
+ keyword_tokens.append(current_keyword_tokens)
53
+ current_keyword_tokens = []
54
+
55
+ keywords = tokenizer.batch_decode(keyword_tokens)
56
+ keywords = list(set(keywords))
57
+
58
+ similar_websites = []
59
+ for keyword in keywords[:3]:
60
+ websites = googlesearch.search(
61
+ keyword,
62
+ tld="com",
63
+ lang="en",
64
+ num=3,
65
+ stop=3,
66
+ pause=0.5,
67
+ )
68
+ similar_websites += list(websites)
69
+
70
+ return "\n".join(similar_websites)
71
+
72
+
73
  summarize_interface = gr.Interface(fn=summarize, inputs="text", outputs="text")
74
  question_interface = gr.Interface(fn=generate_question, inputs="text", outputs="text")
75
+ similar_articles_interface = gr.Interface(fn=get_similar_articles, inputs="text", outputs="text")
76
  tabs = gr.TabbedInterface(
77
+ [summarize_interface, question_interface, similar_articles_interface],
78
+ ["Summarize an article", "Generate a question", "Get similar articles"],
79
  )
80
  tabs.launch()