File size: 2,723 Bytes
b8dde21
a7f0432
b8dde21
 
 
 
 
 
 
a7f0432
69967c7
a258be0
 
 
 
 
 
 
 
a7f0432
6c31d06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8dde21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c31d06
 
b8dde21
6c31d06
b8dde21
 
6c31d06
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from functools import lru_cache

import torch
import gradio as gr
from transformers import (
    AutoModelForSeq2SeqLM, 
    AutoTokenizer, 
    AutoModelForTokenClassification,
)
import googlesearch


@lru_cache(maxsize=1)
def _load_summarizer():
    """Load and cache the summarization tokenizer/model pair.

    The original reloaded both from the checkpoint on every request;
    caching makes all calls after the first fast.
    """
    checkpoint = "sshleifer/distilbart-cnn-12-6"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )


def summarize(text):
    """Summarize ``text`` with DistilBART-CNN.

    Args:
        text: Article text. Inputs longer than the model's limit are
            truncated by the tokenizer.

    Returns:
        A greedy-decoded (``do_sample=False``) summary of at most 100
        newly generated tokens.
    """
    tokenizer, model = _load_summarizer()

    inputs = tokenizer(text, truncation=True, return_tensors="pt").input_ids
    outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@lru_cache(maxsize=1)
def _load_question_generator():
    """Load and cache the question-generation tokenizer/model pair.

    Cached so repeated calls do not reload the checkpoint each time.
    """
    checkpoint = "mrm8488/t5-base-finetuned-question-generation-ap"
    # use_fast=False preserved from the original loading options.
    return (
        AutoTokenizer.from_pretrained(checkpoint, use_fast=False),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )


def generate_question(text):
    """Generate a question about ``text`` with a T5 QG model.

    Args:
        text: Source passage. NOTE(review): the prompt reuses ``text``
            as both the "answer" and the "context" field of the model's
            expected format — confirm this is intentional.

    Returns:
        The generated question string (the model's output with its
        leading "question: " prefix stripped, if present).
    """
    tokenizer, model = _load_question_generator()

    prompt = f"answer: {text} context: {text}"
    inputs = tokenizer(prompt, truncation=True, return_tensors="pt").input_ids
    outputs = model.generate(inputs)

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # BUG FIX: the original's `_, question = generated_text.split("question: ")`
    # raised ValueError unless the marker appeared exactly once. Split at most
    # once and take the last part: if the marker is absent we return the whole
    # output instead of crashing.
    question = generated_text.split("question: ", 1)[-1]

    return question


@lru_cache(maxsize=1)
def _load_keyword_extractor():
    """Load and cache the keyword-extraction tokenizer/model pair."""
    checkpoint = "yanekyuk/bert-keyword-extractor"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForTokenClassification.from_pretrained(checkpoint),
    )


def get_similar_articles(text):
    """Return newline-joined URLs of web pages related to ``text``.

    Extracts keywords from ``text`` with a BERT token classifier, then
    Google-searches up to the first three unique keywords (3 results
    each).

    Args:
        text: Article text to find related pages for.

    Returns:
        URLs separated by newlines; empty string if no keywords found.
    """
    tokenizer, model = _load_keyword_extractor()

    inputs = tokenizer(text, truncation=True, return_tensors="pt")
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(**inputs)

    # Group consecutive non-"O" (label index > 0) tokens into keywords.
    keyword_tokens = []
    current_keyword_tokens = []
    for token, logits in zip(inputs.input_ids[0], outputs.logits[0]):
        token_type = torch.argmax(logits).item()
        if token_type > 0:
            current_keyword_tokens.append(token.item())
        elif current_keyword_tokens:
            keyword_tokens.append(current_keyword_tokens)
            current_keyword_tokens = []
    # BUG FIX: flush a keyword run that extends to the end of the
    # sequence — the original only flushed on a following non-keyword
    # token, silently dropping a trailing keyword.
    if current_keyword_tokens:
        keyword_tokens.append(current_keyword_tokens)

    keywords = tokenizer.batch_decode(keyword_tokens)
    # BUG FIX: list(set(...)) had nondeterministic order, so [:3] below
    # picked an arbitrary subset per run; dict.fromkeys dedupes while
    # preserving first-seen order.
    keywords = list(dict.fromkeys(keywords))

    similar_websites = []
    for keyword in keywords[:3]:
        websites = googlesearch.search(
            keyword,
            tld="com",
            lang="en",
            num=3,
            stop=3,
            pause=0.5,
        )
        similar_websites += list(websites)

    return "\n".join(similar_websites)


# Build one tab per feature; each wraps a single text-in/text-out function.
summarize_interface = gr.Interface(fn=summarize, inputs="text", outputs="text")
question_interface = gr.Interface(fn=generate_question, inputs="text", outputs="text")
similar_articles_interface = gr.Interface(fn=get_similar_articles, inputs="text", outputs="text")
# Tab labels are matched positionally to the interface list above.
tabs = gr.TabbedInterface(
    [summarize_interface, question_interface, similar_articles_interface], 
    ["Summarize an article", "Generate a question", "Get similar articles"],
)
# NOTE(review): launch() runs at import time, blocking the importer;
# consider guarding with `if __name__ == "__main__":`.
tabs.launch()