Spaces:
Running
Running
# -*- coding: utf-8 -*-
"""
Created on Fri May 26 14:07:22 2023
@author: vibin
"""
import re
from typing import List

import pandas as pd
import streamlit as st
from pandasql import PandaSQLException
from pandasql import sqldf
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
### Main
# Sidebar radio that selects which of the two demo pages this run renders.
nav = st.sidebar.radio("Navigation", ["TAPAS", "Text2SQL"])

if nav == "TAPAS":
    # --- Page header -------------------------------------------------------
    # Of the three equal-width columns only the middle one receives content.
    _, title_col, _ = st.columns(3)
    title_col.title("TAPAS")
    # Second row: a narrow empty column followed by a wide subtitle column.
    _, subtitle_col = st.columns([3, 12])
    subtitle_col.text("Tabular Data Text Extraction using text")

    # --- Data --------------------------------------------------------------
    # Every cell is cast to str before the frame is displayed and handed to
    # the table-question-answering pipeline below.
    table = pd.read_csv("data.csv").astype(str)
    st.text("DataSet - ")
    st.dataframe(table, width=3000, height=400)
    st.title("")  # blank title, apparently used as a vertical spacer

    # --- Question input ----------------------------------------------------
    # Preset questions; the selected one pre-fills the editable text area.
    lst_q = [
        "Which country has low medicare",
        "Who are the patients from india",
        "Patients who have Edema",
        "CUI code for diabetes patients",
        "Patients having oxygen less than 94 but 91",
    ]
    preset_question = st.selectbox("Choose your text", lst_q, index=0)
    st.title("")
    question = st.text_area("TAPAS Input", preset_question)

    # --- Prediction --------------------------------------------------------
    if st.button("Predict"):
        if not question.strip():
            # Nothing to ask: skip loading the model and prompt the user.
            st.warning("Please enter a question before pressing Predict.")
        else:
            # The pipeline is constructed anew on every click of the button.
            tqa = pipeline(
                task="table-question-answering",
                model="google/tapas-base-finetuned-wtq",
            )
            answer = tqa(table=table, query=question)["answer"]
            st.text("Output - ")
            st.success(f"{answer}")
elif nav == "Text2SQL": | |
### Function | |
def prepare_input(question: str, table: List[str]): | |
table_prefix = "table:" | |
question_prefix = "question:" | |
join_table = ",".join(table) | |
inputs = f"{question_prefix} {question} {table_prefix} {join_table}" | |
input_ids = tokenizer(inputs, max_length=512, return_tensors="pt").input_ids | |
return input_ids | |
def inference(question: str, table: List[str]) -> str: | |
input_data = prepare_input(question=question, table=table) | |
input_data = input_data.to(model.device) | |
outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=700) | |
result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True) | |
return result | |
col1 , col2, col3 = st.columns(3) | |
col2.title("Text2SQL") | |
col3 , col4 = st.columns([1,20]) | |
col4.text("Text will be converted to SQL Query and can extract the data from DataSet") | |
# Import Data | |
df_qna = pd.read_csv("data.csv", encoding= 'unicode_escape') | |
st.title("") | |
st.text("DataSet - ") | |
st.dataframe(df_qna,width=3000,height= 500) | |
st.title("") | |
lst_q = ["what interface is measure indicator code = 72_HR_ABX and version is 1 and source is TD", "get class code with measure = 72_HR_ABX", "get sum of version for Class_Code is Antibiotic Stewardship", "what interface is measure indicator code = 72_HR_ABX"] | |
v2 = st.selectbox("Choose your text",lst_q,index = 0) | |
st.title("") | |
sql_txt = st.text_area("Text for SQL Conversion",v2) | |
if st.button("Predict"): | |
tokenizer = AutoTokenizer.from_pretrained("juierror/flan-t5-text2sql-with-schema") | |
model = AutoModelForSeq2SeqLM.from_pretrained("juierror/flan-t5-text2sql-with-schema") | |
# text = "what interface is measure indicator code = 72_HR_ABX and version is 1 and source is TD" | |
table_name = "df_qna" | |
table_col = ["Patient_Name","Country","Disease","CUI","Snomed","Oxygen_Rate","Med_Type","Admission_Date"] | |
txt_sql = inference(question=sql_txt, table=table_col) | |
### SQL Modification | |
txt_sql = txt_sql.replace("table",table_name) | |
sql_quotes = [] | |
for match in re.finditer("=",txt_sql): | |
new_txt = txt_sql[match.span()[1]+1:] | |
try: | |
match2 = re.search("AND",new_txt) | |
sql_quotes.append((new_txt[:match2.span()[0]]).strip()) | |
except: | |
sql_quotes.append(new_txt.strip()) | |
for i in sql_quotes: | |
qts = "'" + i + "'" | |
txt_sql = txt_sql.replace(i, qts) | |
st.success(f"{txt_sql}") | |
all_students = sqldf(txt_sql) | |
st.text("Output - ") | |
st.write(all_students) | |