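"""Streamlit app that generates a recipe from a comma-separated list of
ingredients using the flax-community/t5-recipe-generation model, running
inference on CPU.

Run with: streamlit run app.py
Requires: streamlit, torch, transformers, sentencepiece
"""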
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Ensure SentencePiece is installed
try:
    import sentencepiece
except ImportError:
    st.error("SentencePiece is not installed. Please install it using: pip install sentencepiece")
    st.stop()

# Load the model and tokenizer with caching
@st.cache_resource
def load_model():
    model_name = "flax-community/t5-recipe-generation"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        
        # Run on CPU in float32 (the standard precision for CPU inference)
        # and switch to eval mode to disable dropout during generation
        model = model.to('cpu').float()
        model.eval()
        
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()

# Generate recipe function with error handling
def generate_recipe(ingredients, tokenizer, model, max_length=512):
    # Prepare input; the model card for flax-community/t5-recipe-generation
    # expects inputs prefixed with "items: " followed by the ingredient list
    input_text = f"items: {ingredients}"
    
    try:
        # Use torch no_grad to reduce memory consumption
        with torch.no_grad():
            input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=max_length, truncation=True)
            
            # Adjust generation parameters for faster CPU inference
            output_ids = model.generate(
                input_ids, 
                max_length=max_length, 
                num_return_sequences=1, 
                no_repeat_ngram_size=2,
                num_beams=4,  # Reduced beam search for faster CPU processing
                early_stopping=True
            )
        
        # Decode and clean the output
        recipe = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return recipe
    except Exception as e:
        st.error(f"Error generating recipe: {e}")
        return None

# Streamlit app
def main():
    st.title("🍳 AI Recipe Generator")
    
    # Sidebar for input
    st.sidebar.header("Ingredient Input")
    ingredients_input = st.sidebar.text_area(
        "Enter ingredients (comma-separated):", 
        placeholder="e.g. chicken, tomatoes, onions, garlic"
    )
    
    # Load model
    tokenizer, model = load_model()
    
    # Generate button
    if st.sidebar.button("Generate Recipe"):
        if ingredients_input:
            with st.spinner("Generating recipe..."):
                recipe = generate_recipe(ingredients_input, tokenizer, model)
                
                if recipe:
                    # Display recipe sections
                    st.subheader("🥘 Generated Recipe")
                    st.write(recipe)
                else:
                    st.error("Failed to generate recipe. Please try again.")
        else:
            st.warning("Please enter some ingredients!")

    # Additional UI elements
    st.sidebar.markdown("---")
    st.sidebar.info("Enter ingredients and click 'Generate Recipe'")

if __name__ == "__main__":
    main()