{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "3xnrF3UB6ev0"
      ],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Model Inference"
      ],
      "metadata": {
        "id": "33C47swS80_1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Install Dependencies\n",
        "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
        "%pip install -q transformers"
      ],
      "metadata": {
        "cellView": "form",
        "id": "noaoheUjvGbd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title [Run this if using Nvidia Ampere or newer. This will significantly speed up the process]\n",
        "import torch\n",
        "\n",
        "# Permit TF32 tensor-core math for CUDA matmuls and for cuDNN convolutions.\n",
        "for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):\n",
        "    _tf32_backend.allow_tf32 = True"
      ],
      "metadata": {
        "cellView": "form",
        "id": "MkGgqW87eUsQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "NZLqjuWEtCDy"
      },
      "outputs": [],
      "source": [
        "#@title Imports\n",
        "import os\n",
        "import shutil\n",
        "\n",
        "import torch\n",
        "from PIL import Image\n",
        "from transformers import pipeline\n",
        "\n",
        "# Build the aesthetic classifier once; the inference and batch cells reuse `pipe`.\n",
        "# device=0 selects the first CUDA GPU; fall back to CPU (-1) so this cell does\n",
        "# not fail on a runtime without a GPU.\n",
        "pipe_device = 0 if torch.cuda.is_available() else -1\n",
        "pipe = pipeline(\"image-classification\", model=\"shadowlilac/aesthetic-shadow\", device=pipe_device)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Inference\n",
        "\n",
        "# Input image file\n",
        "single_image_file = \"image_1.png\" #@param {type:\"string\"}\n",
        "\n",
        "# Passing a list returns one list of {'label': ..., 'score': ...} dicts per image.\n",
        "result = pipe(images=[single_image_file])\n",
        "\n",
        "prediction_single = result[0]\n",
        "hq_score = next(p['score'] for p in prediction_single if p['label'] == 'hq')\n",
        "# The score is a probability in [0, 1]; scale by 100 so it matches the \"%\" label.\n",
        "print(f\"Prediction: {hq_score * 100:.2f}% High Quality\")\n",
        "Image.open(single_image_file)"
      ],
      "metadata": {
        "cellView": "form",
        "id": "r1R-L2r-0uo2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Batch Mode"
      ],
      "metadata": {
        "id": "3xnrF3UB6ev0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Batch parameters\n",
        "# Define the paths for the input folder and output folders\n",
        "input_folder = \"input_folder\" #@param {type:\"string\"}\n",
        "output_folder_hq = \"output_hq_folder\" #@param {type:\"string\"}\n",
        "output_folder_lq = \"output_lq_folder\" #@param {type:\"string\"}\n",
        "# Threshold on the 'hq' score: images scoring >= this value are copied to output_folder_hq,\n",
        "# all others to output_folder_lq\n",
        "batch_hq_threshold = 0.5 #@param {type:\"number\"}\n",
        "# Number of images classified per pipeline call; an integer field so range() accepts it\n",
        "batch_size = 8 #@param {type:\"integer\"}"
      ],
      "metadata": {
        "cellView": "form",
        "id": "VlPgrJf4wpHo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Execute Batch Job\n",
        "\n",
        "# Make sure both destination folders exist; shutil.copy does not create them.\n",
        "for folder in (output_folder_hq, output_folder_lq):\n",
        "    os.makedirs(folder, exist_ok=True)\n",
        "\n",
        "# List all image files in the input folder (sorted so runs are reproducible)\n",
        "image_files = sorted(\n",
        "    os.path.join(input_folder, f)\n",
        "    for f in os.listdir(input_folder)\n",
        "    if f.lower().endswith(('.png', '.jpg', '.jpeg'))\n",
        ")\n",
        "\n",
        "# range() needs a positive int step; the form value may arrive as a float.\n",
        "step = max(1, int(batch_size))\n",
        "\n",
        "# Process images in batches\n",
        "for i in range(0, len(image_files), step):\n",
        "    batch = image_files[i:i + step]\n",
        "\n",
        "    # Classify the chunk; batch_size lets the pipeline push it through the model\n",
        "    # together rather than one image at a time.\n",
        "    results = pipe(images=batch, batch_size=step)\n",
        "\n",
        "    # results[k] holds the {'label', 'score'} predictions for batch[k]\n",
        "    for image_path, predictions in zip(batch, results):\n",
        "        hq_score = next(p['score'] for p in predictions if p['label'] == 'hq')\n",
        "\n",
        "        # Determine the destination folder based on the prediction and threshold\n",
        "        destination_folder = output_folder_hq if hq_score >= batch_hq_threshold else output_folder_lq\n",
        "\n",
        "        # Copy the image to the appropriate folder\n",
        "        shutil.copy(image_path, os.path.join(destination_folder, os.path.basename(image_path)))\n",
        "\n",
        "print(\"Classification and sorting complete.\")"
      ],
      "metadata": {
        "cellView": "form",
        "id": "RG01mcYf4DvK"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}