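# "Spaces: Sleeping" -- status text left over from the web page this file was
# copied from (Hugging Face Spaces UI); it is not part of the program.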
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Probe for SentencePiece at import time (presumably needed by the T5 tokenizer
# loaded below) so a missing install surfaces as an actionable message in the
# UI, after which the script run is halted.
try:
    import sentencepiece  # noqa: F401 -- imported only to check availability
except ImportError:
    st.error(
        "SentencePiece is not installed. "
        "Please install it using: pip install sentencepiece"
    )
    st.stop()
@st.cache_resource
def _load_model_cached(model_name):
    """Load and return ``(tokenizer, model)`` for *model_name*, cached by Streamlit.

    ``st.cache_resource`` keeps the returned objects across script reruns and
    sessions, so the checkpoint is loaded once rather than on every widget
    interaction. Exceptions propagate to the caller unchanged.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # Pin to CPU and make the float32 dtype explicit; this does not reduce
    # memory compared with a default float32 load, it only guarantees the dtype.
    model = model.to("cpu").float()
    return tokenizer, model


def load_model():
    """Return the cached ``(tokenizer, model)`` pair for the recipe generator.

    Loading happens at most once per server process (see ``_load_model_cached``).
    Streamlit UI handling stays outside the cached function. On any loading
    error the message is shown with ``st.error`` and the script run is halted
    with ``st.stop()``.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model on the CPU in float32.
    """
    model_name = "flax-community/t5-recipe-generation"
    try:
        return _load_model_cached(model_name)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()
def _format_recipe(text):
    """Render the model's ``<section>``/``<sep>`` markers as Markdown line breaks.

    The model card for flax-community/t5-recipe-generation post-processes its
    output by treating ``<section>`` as a section break and ``<sep>`` as an
    item separator. Here each section becomes a paragraph and each item its own
    line. Text without markers is returned stripped; empty text yields ``""``.
    """
    paragraphs = []
    for section in text.split("<section>"):
        lines = [line.strip() for line in section.split("<sep>")]
        lines = [line for line in lines if line]
        if lines:
            # Two trailing spaces force a Markdown hard line break in st.write.
            paragraphs.append("  \n".join(lines))
    return "\n\n".join(paragraphs)


def generate_recipe(ingredients, tokenizer, model, max_length=512):
    """Generate a recipe for the given ingredients.

    Args:
        ingredients: Ingredients separated by commas (newlines also accepted).
        tokenizer: Tokenizer returned by ``load_model``.
        model: Seq2seq model returned by ``load_model``.
        max_length: Token limit used both to truncate the prompt and as the
            generation ``max_length``.

    Returns:
        The recipe as Markdown-ready text, or ``None`` when no ingredients were
        given or generation failed (the failure is also shown via ``st.error``).
    """
    # Normalise " a ,b\nc" into ["a", "b", "c"] and drop empty entries.
    raw_items = (ingredients or "").replace("\n", ",").split(",")
    items = [item.strip() for item in raw_items]
    items = [item for item in items if item]
    if not items:
        return None

    # The model was fine-tuned on prompts of the form "items: a, b, c" (see its
    # model card), so the prompt must use that exact prefix.
    input_text = "items: " + ", ".join(items)
    try:
        # no_grad: inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            encoded = tokenizer(
                input_text,
                return_tensors="pt",
                max_length=max_length,
                truncation=True,
            )
            output_ids = model.generate(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_length=max_length,
                num_return_sequences=1,
                # Size 2 would forbid natural repeats such as "1 cup" appearing
                # twice in an ingredient list; 3 still curbs degenerate loops.
                no_repeat_ngram_size=3,
                num_beams=4,  # modest beam width keeps CPU inference tolerable
                early_stopping=True,
            )
        recipe = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return _format_recipe(recipe)
    except Exception as e:
        st.error(f"Error generating recipe: {e}")
        return None
def main():
    """Build the page: ingredient input in the sidebar, a Generate button, and the result."""
    st.title("🍳 AI Recipe Generator")

    st.sidebar.header("Ingredient Input")
    ingredients_input = st.sidebar.text_area(
        "Enter ingredients (comma-separated):",
        placeholder="e.g. chicken, tomatoes, onions, garlic",
    )

    # load_model() halts the script run itself (st.stop) if loading fails.
    tokenizer, model = load_model()

    if st.sidebar.button("Generate Recipe"):
        # Whitespace-only input counts as empty.
        ingredients = (ingredients_input or "").strip()
        if not ingredients:
            st.warning("Please enter some ingredients!")
        else:
            with st.spinner("Generating recipe..."):
                recipe = generate_recipe(ingredients, tokenizer, model)
            # Render after the spinner has closed.
            if recipe:
                st.subheader("Generated Recipe")
                st.write(recipe)
            else:
                st.error("Failed to generate recipe. Please try again.")

    st.sidebar.markdown("---")
    st.sidebar.info("Enter ingredients and click 'Generate Recipe'")
if __name__ == "__main__":
    # `streamlit run <file>` executes this file as the __main__ module, so the
    # app renders on every script run; a plain import does not build the page.
    main()