File size: 4,276 Bytes
f735c49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import tensorflow as tf
# hf_hub_download ํ•จ์ˆ˜๋ฅผ ์ง์ ‘ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด import ํ•ฉ๋‹ˆ๋‹ค.
from huggingface_hub import hf_hub_download 
from PIL import Image, ImageOps
import numpy as np
import os

def load_model_from_hf(model_id, model_filename="model.keras"):
    """

    Hugging Face Hub์—์„œ Keras ๋ชจ๋ธ ํŒŒ์ผ์„ ๋‹ค์šด๋กœ๋“œํ•œ ํ›„,

    ๋กœ์ปฌ์— ์ €์žฅ๋œ ํŒŒ์ผ์„ ์ด์šฉํ•ด ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค.

    """
    try:
        print(f"Downloading model '{model_id}' from Hugging Face Hub...")
        # 1. Hugging Face Hub์—์„œ ๋ชจ๋ธ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ
        #    ์ด ํ•จ์ˆ˜๋Š” ๋‹ค์šด๋กœ๋“œ๋œ ํŒŒ์ผ์ด ์ €์žฅ๋œ ๋กœ์ปฌ ๊ฒฝ๋กœ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
        model_path = hf_hub_download(repo_id=model_id, filename=model_filename)
        print(f"Model downloaded to: {model_path}")
        
        # 2. ๋กœ์ปฌ์— ์ €์žฅ๋œ ๋ชจ๋ธ ํŒŒ์ผ ๋กœ๋“œ
        print("Loading model from local file...")
        model = tf.keras.models.load_model(model_path)
        print("Model loaded successfully!")
        return model
        
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Please check if the model ID and filename are correct on Hugging Face Hub.")
        return None

def preprocess_image(image_path):
    """

    ์‚ฌ์šฉ์ž ์ด๋ฏธ์ง€๋ฅผ MNIST ๋ฐ์ดํ„ฐ์…‹ ํ˜•์‹์— ๋งž๊ฒŒ ์ „์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.

    """
    try:
        # 1. ์ด๋ฏธ์ง€ ์—ด๊ธฐ
        img = Image.open(image_path)
        
        # 2. ํ‘๋ฐฑ(Grayscale)์œผ๋กœ ๋ณ€ํ™˜
        img = img.convert('L')
        
        # 3. ์ƒ‰์ƒ ๋ฐ˜์ „ (MNIST๋Š” ํฐ์ƒ‰ ๊ธ€์”จ/๊ฒ€์€ ๋ฐฐ๊ฒฝ, ์‚ฌ์šฉ์ž๋Š” ๋ณดํ†ต ๊ฒ€์€ ๊ธ€์”จ/ํฐ์ƒ‰ ๋ฐฐ๊ฒฝ)
        if np.mean(np.array(img)) > 128:
            img = ImageOps.invert(img)

        # 4. 28x28 ํฌ๊ธฐ๋กœ ๋ฆฌ์‚ฌ์ด์ฆˆ
        img = img.resize((28, 28), Image.Resampling.LANCZOS)
        
        # 5. Numpy ๋ฐฐ์—ด๋กœ ๋ณ€ํ™˜ํ•˜๊ณ  0~1 ์‚ฌ์ด ๊ฐ’์œผ๋กœ ์ •๊ทœํ™”
        img_array = np.array(img).astype('float32') / 255.0
        
        # 6. ๋ชจ๋ธ์˜ ์ž…๋ ฅ ํ˜•ํƒœ์— ๋งž๊ฒŒ ์ฐจ์› ํ™•์žฅ (1, 28, 28, 1)
        processed_img = np.expand_dims(img_array, axis=0)
        processed_img = np.expand_dims(processed_img, axis=-1)
        
        return processed_img
    except FileNotFoundError:
        print(f"Error: The file '{image_path}' was not found.")
        return None
    except Exception as e:
        print(f"Error processing image: {e}")
        return None

def main():
    # Hugging Face์— ์—…๋กœ๋“œ๋œ ๋ชจ๋ธ ID
    model_id = "OneclickAI/CNN_test_Model"
    
    # ๋ชจ๋ธ ๋กœ๋“œ (์ˆ˜์ •๋œ ํ•จ์ˆ˜ ํ˜ธ์ถœ)
    # ์ด์ „ train.py์—์„œ model.save("my_keras_model.keras")๋กœ ์ €์žฅํ–ˆ์œผ๋ฏ€๋กœ, 
    # Hub์— ์˜ฌ๋ผ๊ฐ„ ํŒŒ์ผ ์ด๋ฆ„์€ 'my_keras_model.keras'์ผ ๊ฒƒ์ž…๋‹ˆ๋‹ค. 
    # ๋งŒ์•ฝ ๋‹ค๋ฅธ ์ด๋ฆ„์œผ๋กœ ์˜ฌ๋ ธ๋‹ค๋ฉด ํ•ด๋‹น ํŒŒ์ผ๋ช…์œผ๋กœ ์ˆ˜์ •ํ•ด์ฃผ์„ธ์š”.
    # (Hugging Face Hub์—์„œ๋Š” ๋ณดํ†ต 'model.keras' ๋ผ๋Š” ์ด๋ฆ„์„ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค)
    model = load_model_from_hf(model_id, model_filename="my_keras_model.keras") 
    
    if model is None:
        return # ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ ์‹œ ์ข…๋ฃŒ

    # ์‚ฌ์šฉ์ž์—๊ฒŒ ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ๋ฅผ ๊ณ„์†ํ•ด์„œ ์ž…๋ ฅ๋ฐ›์Œ
    while True:
        user_input = input("\nPlease enter the path to your image (or type 'exit' to quit): ")
        
        if user_input.lower() == 'exit':
            break
            
        if not os.path.exists(user_input):
            print(f"File not found at '{user_input}'. Please check the path and try again.")
            continue

        # ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ
        processed_image = preprocess_image(user_input)
        
        if processed_image is not None:
            # ๋ชจ๋ธ ์˜ˆ์ธก ์ˆ˜ํ–‰
            predictions = model.predict(processed_image)
            
            # ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง„ ํด๋ž˜์Šค(์ˆซ์ž)๋ฅผ ์ฐพ์Œ
            predicted_digit = np.argmax(predictions[0])
            confidence = np.max(predictions[0]) * 100
            
            print("\n--- Prediction Result ---")
            print(f"Predicted Digit: {predicted_digit}")
            print(f"Confidence: {confidence:.2f}%")
            print("-------------------------")

if __name__ == "__main__":
    main()