File size: 5,113 Bytes
be99ca1
 
 
 
 
 
 
 
 
 
333e148
be99ca1
 
333e148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be99ca1
 
 
 
333e148
be99ca1
333e148
 
be99ca1
333e148
be99ca1
 
 
 
333e148
 
 
 
be99ca1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333e148
be99ca1
 
 
 
 
 
 
 
 
 
 
 
333e148
be99ca1
333e148
be99ca1
333e148
be99ca1
 
 
 
 
 
 
 
 
 
 
 
333e148
be99ca1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import io
import logging
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import TFTModel, NBEATSModel
from sklearn.preprocessing import LabelEncoder

# ----------------------------
# SAFE DATASET LOADER
# ----------------------------
def _read_csv_with_encoding_fallback(source):
    """Read a CSV from *source*, retrying as Latin-1 if UTF-8 decoding fails."""
    try:
        return pd.read_csv(source, encoding="utf-8")
    except UnicodeDecodeError:
        # Latin-1 maps every byte to a character, so the retry cannot fail
        # on decoding (it may still fail for other reasons).
        return pd.read_csv(source, encoding="latin1")


def load_dataset(path="dataset.csv", url=None):
    """Load the dataset from a CSV, falling back to ``url`` when given.

    Args:
        path: Primary source handed to ``pd.read_csv`` (a local file path
            by default).
        url: Optional secondary source, tried only when reading ``path``
            fails.

    Returns:
        The parsed ``pandas.DataFrame``.

    Raises:
        OSError, ValueError: The error from reading ``path`` when no ``url``
            is given, or the error from reading ``url`` when the fallback
            fails as well.
    """
    try:
        return _read_csv_with_encoding_fallback(path)
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; pandas' parser and
        # empty-data errors derive from ValueError. Previously a missing file
        # escaped before the URL fallback was ever consulted.
        if not url:
            raise
        return _read_csv_with_encoding_fallback(url)

# Prefer the local copy; otherwise read the dataset from the fallback URL.
# NOTE(review): the URL below looks like a placeholder — confirm before deploying.
url = "https://raw.githubusercontent.com/yourusername/yourrepo/main/dataset.csv"
if os.path.exists("dataset.csv"):
    df = load_dataset("dataset.csv")
else:
    # Pass the URL as the source itself. Calling load_dataset(url=url) would
    # first try the default local path, which is known to be missing here.
    df = load_dataset(url)

# ----------------------------
# Preprocessing
# ----------------------------
# Parse the timestamp column and put the rows in chronological order.
df = df.assign(datetime=pd.to_datetime(df["datetime"])).sort_values(by="datetime")

# Integer-encode the weather icon labels, when that column is present.
encoder = LabelEncoder()
if "icon" in df.columns:
    icon_codes = encoder.fit_transform(df["icon"])
    df["icon_encoded"] = icon_codes

# Build the target series from the PV output column, then fit the scaler on
# it and keep the scaled copy for model training.
series = TimeSeries.from_dataframe(df, time_col="datetime", value_cols="pv_output_kWh")
scaler = Scaler()
series_scaled = scaler.fit_transform(series)

# Use the pre-trained TFT checkpoint when it can be loaded; otherwise train a
# small N-BEATS model on the scaled series as a fallback.
try:
    model = TFTModel.load_from_checkpoint("tft_pretrained", work_dir="./")
except Exception as exc:
    # Best-effort by design, but record why the checkpoint was skipped so a
    # broken or missing checkpoint is not silently replaced.
    logging.getLogger(__name__).warning(
        "Could not load TFT checkpoint 'tft_pretrained' (%s); "
        "training an NBEATS fallback model instead.",
        exc,
    )
    model = NBEATSModel(input_chunk_length=30, output_chunk_length=7, n_epochs=10)
    model.fit(series_scaled)

# ----------------------------
# EDA FUNCTIONS
# ----------------------------
def eda_summary():
    """Return the ``describe()`` table of the module-level ``df`` as plain text."""
    stats_table = df.describe()
    # With no buffer argument, to_string() returns the rendered text directly.
    return stats_table.to_string()

def eda_histogram(column):
    """Plot the distribution of one column of the module-level ``df``.

    Args:
        column: Name of the column to plot.

    Returns:
        The matplotlib ``Figure`` holding a 20-bin histogram with a KDE
        overlay.
    """
    # Draw on an explicit Figure/Axes instead of relying on pyplot's
    # "current figure", which is shared process-wide.
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.histplot(df[column], kde=True, bins=20, ax=ax)
    ax.set_title(f"Distribution of {column}")
    fig.tight_layout()
    # Remove the figure from pyplot's registry so repeated requests do not
    # accumulate open figures; the returned object can still be rendered.
    plt.close(fig)
    return fig

def eda_correlation():
    """Plot a correlation heatmap of the numeric columns of ``df``.

    Returns:
        The matplotlib ``Figure`` holding the annotated heatmap.
    """
    # Restrict to numeric/bool columns: recent pandas versions no longer drop
    # datetime or string columns (such as "datetime" and "icon") implicitly
    # in corr() and raise instead.
    numeric_df = df.select_dtypes(include=["number", "bool"])

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(numeric_df.corr(), annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
    ax.set_title("Correlation Heatmap")
    fig.tight_layout()
    # Remove the figure from pyplot's registry so repeated requests do not
    # accumulate open figures; the returned object can still be rendered.
    plt.close(fig)
    return fig

def eda_timeseries():
    """Plot the ``pv_output_kWh`` column of ``df`` against ``datetime``.

    Returns:
        The matplotlib ``Figure`` holding the line plot.
    """
    # Explicit Figure/Axes avoid depending on pyplot's shared current figure.
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(df["datetime"], df["pv_output_kWh"], label="PV Output (kWh)")
    ax.set_title("Time-Series Trend of PV Output")
    ax.set_xlabel("Date")
    ax.set_ylabel("PV Output (kWh)")
    ax.legend()
    fig.tight_layout()
    # Remove the figure from pyplot's registry so repeated requests do not
    # accumulate open figures; the returned object can still be rendered.
    plt.close(fig)
    return fig

# ----------------------------
# FORECAST FUNCTION
# ----------------------------
def forecast_pv(horizon, weather_condition):
    """Forecast PV output and apply a weather-dependent scaling factor.

    Args:
        horizon: One of "24 Hours", "3 Days", "7 Days", "14 Days"; mapped to
            a number of prediction steps (presumably hourly — confirm
            against the dataset frequency).
        weather_condition: Label looked up in the adjustment table below;
            unknown labels leave the forecast unscaled (factor 1.0).

    Returns:
        A ``(figure, peak_info)`` tuple: the matplotlib ``Figure`` showing
        recent history, the base forecast and the adjusted forecast, and a
        text line describing the adjusted forecast's maximum and its time.

    Raises:
        KeyError: If ``horizon`` is not one of the supported labels.
    """
    horizon_map = {"24 Hours": 24, "3 Days": 72, "7 Days": 168, "14 Days": 336}
    steps = horizon_map[horizon]

    # Predict in the scaled space, then map back to the original units.
    forecast = scaler.inverse_transform(model.predict(steps))

    # Weather impact adjustment: fixed multipliers applied to the forecast.
    adjustment = {
        "Clear": 1.0,
        "Partly Cloudy": 0.85,
        "Cloudy": 0.65,
        "Fog": 0.55,
        "Smoke/Dust": 0.6,
        "Winter": 0.7,
        "Rain": 0.5
    }
    adj_factor = adjustment.get(weather_condition, 1.0)
    forecast_adj = forecast * adj_factor

    # Plot. Keep a handle on the new figure rather than asking pyplot for the
    # "current" one at return time; the series' plot() calls draw on it
    # because it is the current figure right after creation.
    fig = plt.figure(figsize=(10, 4))
    series[-7*24:].plot(label="History")  # last 168 points of history
    forecast.plot(label="Forecast (Base)")
    forecast_adj.plot(label=f"Forecast (Adjusted: {weather_condition})")
    plt.legend()
    plt.title(f"PV Forecast for {horizon}")
    plt.tight_layout()
    # Remove the figure from pyplot's registry so repeated requests do not
    # accumulate open figures; the returned object can still be rendered.
    plt.close(fig)

    # Peak info, computed from a single extraction of the forecast values.
    adj_values = forecast_adj.values()
    peak_time = forecast_adj.time_index[np.argmax(adj_values)]
    peak_val = np.max(adj_values)
    peak_info = f"🔺 Peak PV Output: {round(peak_val,2)} kWh at {peak_time}"

    return fig, peak_info

# ----------------------------
# GRADIO DASHBOARD
# ----------------------------
# One sub-tab per EDA view; each wraps one of the eda_* functions.
eda_tab = gr.TabbedInterface(
    [
        gr.Interface(fn=eda_summary, inputs=[], outputs="text", title="Summary Stats"),
        gr.Interface(
            fn=eda_histogram,
            # Pass a plain list: a pandas Index has no unambiguous truth
            # value and is not the list type the Dropdown expects.
            inputs=gr.Dropdown(df.columns.tolist(), label="Select Column"),
            outputs="plot",
            title="Histogram",
        ),
        gr.Interface(fn=eda_correlation, inputs=[], outputs="plot", title="Correlation Heatmap"),
        gr.Interface(fn=eda_timeseries, inputs=[], outputs="plot", title="Time Series Trend")
    ],
    tab_names=["Summary", "Histogram", "Correlation", "Time Series"]
)

# Forecasting form. Both inputs get a preselected value: forecast_pv looks the
# horizon label up in a dict, so submitting with nothing selected (None)
# would raise KeyError.
forecast_tab = gr.Interface(
    fn=forecast_pv,
    inputs=[
        gr.Radio(
            ["24 Hours", "3 Days", "7 Days", "14 Days"],
            value="24 Hours",
            label="Select Forecast Horizon",
        ),
        gr.Dropdown(
            ["Clear","Partly Cloudy","Cloudy","Fog","Smoke/Dust","Winter","Rain"],
            value="Clear",
            label="Weather Condition",
        )
    ],
    outputs=[
        gr.Plot(label="Forecast Plot"),
        gr.Textbox(label="Peak Info")
    ],
    title="PV Forecasting"
)

# Top-level app: the EDA dashboard and the forecasting form as two tabs.
app = gr.TabbedInterface([eda_tab, forecast_tab], tab_names=["EDA Dashboard", "Forecasting"])

if __name__ == "__main__":
    app.launch()