CNN_test_Model / test.py
OneclickAI's picture
Final
f735c49 verified
raw
history blame
4.28 kB
import tensorflow as tf
# hf_hub_download ํ•จ์ˆ˜๋ฅผ ์ง์ ‘ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด import ํ•ฉ๋‹ˆ๋‹ค.
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps
import numpy as np
import os
def load_model_from_hf(model_id, model_filename="model.keras"):
"""
Hugging Face Hub์—์„œ Keras ๋ชจ๋ธ ํŒŒ์ผ์„ ๋‹ค์šด๋กœ๋“œํ•œ ํ›„,
๋กœ์ปฌ์— ์ €์žฅ๋œ ํŒŒ์ผ์„ ์ด์šฉํ•ด ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค.
"""
try:
print(f"Downloading model '{model_id}' from Hugging Face Hub...")
# 1. Hugging Face Hub์—์„œ ๋ชจ๋ธ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ
# ์ด ํ•จ์ˆ˜๋Š” ๋‹ค์šด๋กœ๋“œ๋œ ํŒŒ์ผ์ด ์ €์žฅ๋œ ๋กœ์ปฌ ๊ฒฝ๋กœ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
model_path = hf_hub_download(repo_id=model_id, filename=model_filename)
print(f"Model downloaded to: {model_path}")
# 2. ๋กœ์ปฌ์— ์ €์žฅ๋œ ๋ชจ๋ธ ํŒŒ์ผ ๋กœ๋“œ
print("Loading model from local file...")
model = tf.keras.models.load_model(model_path)
print("Model loaded successfully!")
return model
except Exception as e:
print(f"Error loading model: {e}")
print("Please check if the model ID and filename are correct on Hugging Face Hub.")
return None
def preprocess_image(image_path):
"""
์‚ฌ์šฉ์ž ์ด๋ฏธ์ง€๋ฅผ MNIST ๋ฐ์ดํ„ฐ์…‹ ํ˜•์‹์— ๋งž๊ฒŒ ์ „์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.
"""
try:
# 1. ์ด๋ฏธ์ง€ ์—ด๊ธฐ
img = Image.open(image_path)
# 2. ํ‘๋ฐฑ(Grayscale)์œผ๋กœ ๋ณ€ํ™˜
img = img.convert('L')
# 3. ์ƒ‰์ƒ ๋ฐ˜์ „ (MNIST๋Š” ํฐ์ƒ‰ ๊ธ€์”จ/๊ฒ€์€ ๋ฐฐ๊ฒฝ, ์‚ฌ์šฉ์ž๋Š” ๋ณดํ†ต ๊ฒ€์€ ๊ธ€์”จ/ํฐ์ƒ‰ ๋ฐฐ๊ฒฝ)
if np.mean(np.array(img)) > 128:
img = ImageOps.invert(img)
# 4. 28x28 ํฌ๊ธฐ๋กœ ๋ฆฌ์‚ฌ์ด์ฆˆ
img = img.resize((28, 28), Image.Resampling.LANCZOS)
# 5. Numpy ๋ฐฐ์—ด๋กœ ๋ณ€ํ™˜ํ•˜๊ณ  0~1 ์‚ฌ์ด ๊ฐ’์œผ๋กœ ์ •๊ทœํ™”
img_array = np.array(img).astype('float32') / 255.0
# 6. ๋ชจ๋ธ์˜ ์ž…๋ ฅ ํ˜•ํƒœ์— ๋งž๊ฒŒ ์ฐจ์› ํ™•์žฅ (1, 28, 28, 1)
processed_img = np.expand_dims(img_array, axis=0)
processed_img = np.expand_dims(processed_img, axis=-1)
return processed_img
except FileNotFoundError:
print(f"Error: The file '{image_path}' was not found.")
return None
except Exception as e:
print(f"Error processing image: {e}")
return None
def main():
# Hugging Face์— ์—…๋กœ๋“œ๋œ ๋ชจ๋ธ ID
model_id = "OneclickAI/CNN_test_Model"
# ๋ชจ๋ธ ๋กœ๋“œ (์ˆ˜์ •๋œ ํ•จ์ˆ˜ ํ˜ธ์ถœ)
# ์ด์ „ train.py์—์„œ model.save("my_keras_model.keras")๋กœ ์ €์žฅํ–ˆ์œผ๋ฏ€๋กœ,
# Hub์— ์˜ฌ๋ผ๊ฐ„ ํŒŒ์ผ ์ด๋ฆ„์€ 'my_keras_model.keras'์ผ ๊ฒƒ์ž…๋‹ˆ๋‹ค.
# ๋งŒ์•ฝ ๋‹ค๋ฅธ ์ด๋ฆ„์œผ๋กœ ์˜ฌ๋ ธ๋‹ค๋ฉด ํ•ด๋‹น ํŒŒ์ผ๋ช…์œผ๋กœ ์ˆ˜์ •ํ•ด์ฃผ์„ธ์š”.
# (Hugging Face Hub์—์„œ๋Š” ๋ณดํ†ต 'model.keras' ๋ผ๋Š” ์ด๋ฆ„์„ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค)
model = load_model_from_hf(model_id, model_filename="my_keras_model.keras")
if model is None:
return # ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ ์‹œ ์ข…๋ฃŒ
# ์‚ฌ์šฉ์ž์—๊ฒŒ ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ๋ฅผ ๊ณ„์†ํ•ด์„œ ์ž…๋ ฅ๋ฐ›์Œ
while True:
user_input = input("\nPlease enter the path to your image (or type 'exit' to quit): ")
if user_input.lower() == 'exit':
break
if not os.path.exists(user_input):
print(f"File not found at '{user_input}'. Please check the path and try again.")
continue
# ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ
processed_image = preprocess_image(user_input)
if processed_image is not None:
# ๋ชจ๋ธ ์˜ˆ์ธก ์ˆ˜ํ–‰
predictions = model.predict(processed_image)
# ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง„ ํด๋ž˜์Šค(์ˆซ์ž)๋ฅผ ์ฐพ์Œ
predicted_digit = np.argmax(predictions[0])
confidence = np.max(predictions[0]) * 100
print("\n--- Prediction Result ---")
print(f"Predicted Digit: {predicted_digit}")
print(f"Confidence: {confidence:.2f}%")
print("-------------------------")
if __name__ == "__main__":
main()