# Scraped Hugging Face Spaces page header: "Spaces: Sleeping" (Space status at scrape time).
import gradio as gr
import laion_clap
from qdrant_client import QdrantClient

from settings import QDRANT_KEY, QDRANT_URL, ENVIRONMENT
# Qdrant client ####################################################################################
# ENVIRONMENT == "PROD" connects to the remote instance at QDRANT_URL using
# QDRANT_KEY; any other value expects a Qdrant server on localhost:6333.
if ENVIRONMENT == "PROD":
    qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_KEY)
else:
    qdrant_client = QdrantClient("localhost", port=6333)
print("[INFO] Client created...")
# Loading the model ################################################################################
print("[INFO] Loading the model...")
# NOTE(review): model_name is not read by any visible code, and load_ckpt()
# below is called without it, so it does NOT select the checkpoint. It looks
# like a Hugging Face Hub model id rather than a laion_clap checkpoint; kept
# only so any external reference to the name keeps working.
model_name = "laion/larger_clap_music"
model = laion_clap.CLAP_Module(enable_fusion=False)
# Called with no arguments, so laion_clap downloads its default pretrained
# checkpoint for a non-fusion model. Text embeddings are compared against the
# vectors stored in "music_db", so those vectors presumably come from this same
# checkpoint -- verify against the indexing script before changing it.
model.load_ckpt()
# Gradio Interface #################################################################################
max_results = 10  # upper bound on the number of hits requested from Qdrant per query


def sound_search(query, num_outputs=3):
    """Return Audio updates for the clips that best match a text query.

    The query is embedded with the CLAP text encoder and used for a vector
    search over the "music_db" Qdrant collection.

    Args:
        query: Free-text description typed by the user.
        num_outputs: Number of values to return. Gradio maps the returned list
            onto the output components positionally, so this must equal the
            number of components wired to this function (3 in this file's UI).

    Returns:
        A list of exactly ``num_outputs`` ``gr.Audio`` updates: one per hit,
        labelled with its similarity score, followed by cleared players when
        Qdrant returns fewer hits or the query is blank.
    """

    def _empty_slot():
        # Explicit value=None clears any clip left over from a previous query.
        return gr.Audio(value=None, label="no result")

    if not (query and query.strip()):
        # The textbox fires on every edit, including clearing it; skip the
        # embedding + search round-trip and just clear the players.
        return [_empty_slot() for _ in range(num_outputs)]

    # get_text_embedding can't take a one-element batch, so pad it with an
    # empty string and keep only the first row.
    text_embed = model.get_text_embedding([query, ''])[0]
    hits = qdrant_client.search(
        collection_name="music_db",
        query_vector=text_embed,
        # Only num_outputs hits can be displayed; never ask for more than max_results.
        limit=min(max_results, num_outputs),
    )
    updates = [
        gr.Audio(hit.payload['s3_url'], label=f"score: {hit.score}")
        for hit in hits
    ]
    # Pad so Gradio always receives one value per output component; returning
    # too few values makes the event fail.
    updates.extend(_empty_slot() for _ in range(num_outputs - len(updates)))
    return updates
with gr.Blocks() as demo:
    gr.Markdown("""# Sound search database """)
    inp = gr.Textbox(placeholder="What sound are you looking for ?")
    # The comprehension builds three distinct Audio components, one per result
    # slot; sound_search's returned list is mapped onto them in order.
    out = [gr.Audio(label=str(x)) for x in range(3)]
    # Re-run the search whenever the textbox content changes.
    inp.change(sound_search, inp, out)

demo.launch()