import torch
import pickle
import cv2
import os
import numpy as np
from PIL import Image
from transformers import ViTForImageClassification, AutoImageProcessor, ViTImageProcessor, VisionEncoderDecoderModel, AutoTokenizer
from torch.optim import AdamW  # transformers' AdamW is deprecated; use the torch optimizer instead
from torch.utils.data import DataLoader, TensorDataset
import gradio as gr
model_path = '/home/user/app'
train_pickle_path = 'train_data.pickle'
valid_pickle_path = 'valid_data.pickle'
image_directory = 'images'
test_image_path = '/home/user/app/test.jpg'
num_epochs = 5  # number of fine-tuning epochs
label_list = ["小白", "巧巧", "冏媽", "乖狗", "花捲", "超人", "黑胖", "橘子"]
label_dictionary = {"小白": 0, "巧巧": 1, "冏媽": 2, "乖狗": 3, "花捲": 4, "超人": 5, "黑胖": 6, "橘子": 7}
num_classes = len(label_dictionary)  # Adjust according to your classification task
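# Note: label_list must invert label_dictionary (label_list[label_dictionary[name]] == name
# for every dog name), since predict() maps the argmax index back to a name via label_list.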
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("mps")
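# The commented-out line above is an alternative for Apple-silicon Macs,
# where PyTorch exposes the GPU through the "mps" backend.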
def data_generate(dataset):
    images = []
    labels = []
    image_processor = AutoImageProcessor.from_pretrained('google/vit-large-patch16-224-in21k')
    for folder_name in os.listdir(image_directory):
        folder_path = os.path.join(image_directory, folder_name)
        if os.path.isdir(folder_path):
            for image_file in os.listdir(folder_path):
                # Only pick up files belonging to the requested split ('train' or 'valid')
                if image_file.startswith(dataset):
                    image_path = os.path.join(folder_path, image_file)
                    img = cv2.imread(image_path)
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = Image.fromarray(img)
                    img = img.resize((224, 224))
                    inputs = image_processor(images=img, return_tensors="pt")
                    images.append(inputs['pixel_values'].squeeze(0).numpy())
                    # The numeric prefix of the folder name is the class label
                    labels.append(int(folder_name.split('_')[0]))
    images = np.array(images)
    labels = np.array(labels)
    # Pickle the preprocessed arrays so later runs can skip this step
    data = {'img': images, 'label': labels}
    with open(f'{dataset}_data.pickle', 'wb') as f:
        pickle.dump(data, f)
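# Assumed directory layout, inferred from the parsing above (the example folder is
# hypothetical): images/<label>_<name>/<split>_xxx.jpg, e.g. images/0_小白/train_001.jpg.
# The numeric prefix before '_' becomes the class label, and files whose names start
# with 'train' or 'valid' go into the corresponding split.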
def train_model():
    if not os.path.exists(valid_pickle_path):
        data_generate('valid')
    if not os.path.exists(train_pickle_path):
        data_generate('train')
    # Load the train and valid sets
    with open(train_pickle_path, "rb") as f:
        train_data = pickle.load(f)
    with open(valid_pickle_path, "rb") as f:
        valid_data = pickle.load(f)
    # Convert the dataset into torch tensors
    train_inputs = torch.tensor(train_data["img"])
    train_labels = torch.tensor(train_data["label"])
    valid_inputs = torch.tensor(valid_data["img"])
    valid_labels = torch.tensor(valid_data["label"])
    # Create the TensorDatasets
    train_dataset = TensorDataset(train_inputs, train_labels)
    valid_dataset = TensorDataset(valid_inputs, valid_labels)
    # Create the DataLoaders (the validation set does not need shuffling)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False)
    # Define the model and move it to the GPU
    model = ViTForImageClassification.from_pretrained('google/vit-large-patch16-224-in21k', num_labels=num_classes)
    model.to(device)
    # Define the optimizer
    optimizer = AdamW(model.parameters(), lr=1e-4)
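    # Note: lr=1e-4 is on the high side for fully fine-tuning a large ViT;
    # if training diverges, a rate in the 1e-5 to 5e-5 range is a common choice.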
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for i, batch in enumerate(train_loader):
            # Move batch to the device
            batch = [r.to(device) for r in batch]
            # Unpack the inputs from the dataloader
            inputs, labels = batch
            # Clear out the gradients (by default they accumulate)
            optimizer.zero_grad()
            # Forward pass (passing labels makes the model compute the loss)
            outputs = model(inputs, labels=labels)
            loss = outputs.loss
            # Backward pass
            loss.backward()
            # Update parameters using the computed gradients
            optimizer.step()
            # Accumulate the loss
            total_loss += loss.item()
        # Get the average loss for the entire epoch
        avg_loss = total_loss / len(train_loader)
        print('Epoch:', epoch + 1, 'Training Loss:', avg_loss)
        # Evaluate the model on the validation set after each epoch
        model.eval()
        total_correct = 0
        for batch in valid_loader:
            # Move batch to the device
            batch = [t.to(device) for t in batch]
            # Unpack the inputs from the dataloader
            inputs, labels = batch
            # Forward pass without tracking gradients
            with torch.no_grad():
                outputs = model(inputs)
            # Get the predictions
            predictions = torch.argmax(outputs.logits, dim=1)
            # Count the correct predictions
            total_correct += torch.sum(predictions == labels)
        # Calculate the accuracy
        accuracy = total_correct / len(valid_dataset)
        print('Validation accuracy:', accuracy.item())
    model.save_pretrained("model")
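# Note: google/vit-large-patch16-224-in21k has roughly 300M parameters, so
# fine-tuning with batch_size=16 assumes a GPU with generous memory; if you hit
# out-of-memory errors, reduce the batch size or switch to a vit-base checkpoint.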
def predict(upload_image):
    # Load the fine-tuned model and the matching image processor
    model = ViTForImageClassification.from_pretrained(model_path, num_labels=num_classes)
    image_processor = AutoImageProcessor.from_pretrained('google/vit-large-patch16-224-in21k')
    # Gradio hands us a PIL image; make sure it is RGB
    img = upload_image.convert('RGB')
    # Resize the image to 224x224 pixels
    img = img.resize((224, 224))
    # Preprocess the image and run the classifier
    inputs = image_processor(images=img, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_class_idx = logits.argmax(-1).item()
    # Only report a known dog when confident; '不是校狗' means "not a campus dog"
    return label_list[predicted_class_idx] if probabilities.max().item() > 0.90 else '不是校狗'
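# Usage sketch (assumes a hypothetical local file sample.jpg):
#   print(predict(Image.open('sample.jpg')))  # -> a dog's name, or '不是校狗' ("not a campus dog")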
def captioning(upload_image):
    # Load the pretrained image-captioning model (ViT encoder + GPT-2 decoder)
    model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Beam-search generation settings
    max_length = 16
    num_beams = 4
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
    # Gradio hands us a PIL image; make sure it is RGB
    images = [upload_image.convert('RGB')]
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds[-1]
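# num_beams=4 makes generate() track four candidate captions and return the most
# probable one; max_length=16 caps the caption at 16 tokens, trading occasionally
# truncated sentences for faster generation.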
def output(predict_class, caption):
    # Words to splice the caption on, checked in order of priority
    conj = ['are', 'is', 'dog']
    if predict_class == '不是校狗' or caption.find('dog') == -1:
        # Unknown dog, or the caption is not about a dog: return the caption as-is
        print(f'{caption} ({predict_class})')
        return f'{caption} ({predict_class})'
    else:
        # Replace the caption's generic subject with the predicted dog's name
        for c in conj:
            if caption.find(c) != -1:
                print(f'{predict_class} is{caption[caption.find(c) + len(c):]}')
                return f'{predict_class} is{caption[caption.find(c) + len(c):]}'
        print(f'{caption} ({predict_class})')
        return f'{caption} ({predict_class})'
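# Example of the splice: predict_class='小白' with caption='a dog is lying on the grass'
# returns '小白 is lying on the grass' (the text after the first matched word keeps the
# scene description while the dog's name replaces the generic subject).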
if __name__ == '__main__':
    # Fine-tune the classifier only if no saved model is present
    if not os.path.exists(model_path):
        train_model()

    def get_result(upload_image):
        result = output(predict(upload_image), captioning(upload_image))
        return result

    iface = gr.Interface(fn=get_result, inputs=gr.Image(type="pil"), outputs="text")
    iface.launch()
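# On Hugging Face Spaces, launch() serves the interface at the Space's public URL;
# run locally, it defaults to http://127.0.0.1:7860.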