Create app.py
app.py
ADDED
@@ -0,0 +1,241 @@
import streamlit as st
import pandas as pd
import numpy as np
from data_processing import DataProcessor
from model_training import ModelTrainer
from visualizations import Visualizer
from utils import load_data, get_feature_names, save_model, load_saved_model, list_saved_models
import warnings
warnings.filterwarnings('ignore')

st.set_page_config(
    page_title="ML Pipeline for Purple Teaming",
    page_icon="🛡️",
    layout="wide"
)

def main():
    st.title("🛡️ ML Pipeline for Cybersecurity Purple Teaming")

    # Sidebar
    st.sidebar.header("Pipeline Configuration")

    # File upload
    uploaded_file = st.sidebar.file_uploader(
        "Upload Dataset (CSV/JSON)",
        type=['csv', 'json']
    )

    if uploaded_file is not None:
        try:
            df = load_data(uploaded_file)

            # Initialize components
            processor = DataProcessor()
            trainer = ModelTrainer()
            visualizer = Visualizer()

            # Data Processing Section
            st.header("1. Data Processing")
            col1, col2 = st.columns(2)

            with col1:
                st.subheader("Dataset Overview")
                st.write(f"Shape: {df.shape}")
                st.write("Sample Data:")
                st.dataframe(df.head())

            with col2:
                st.subheader("Data Statistics")
                st.write(df.describe())

            # Feature Engineering Configuration
            st.header("2. Feature Engineering")
            col3, col4 = st.columns(2)

            with col3:
                # Basic preprocessing
                handling_strategy = st.selectbox(
                    "Missing Values Strategy",
                    ["mean", "median", "most_frequent", "constant"]
                )
                scaling_method = st.selectbox(
                    "Scaling Method",
                    ["standard", "minmax", "robust"]
                )

                # Advanced Feature Engineering
                st.subheader("Advanced Features")
                use_polynomial = st.checkbox("Use Polynomial Features")
                if use_polynomial:
                    poly_degree = st.slider("Polynomial Degree", 2, 5, 2)

                use_feature_selection = st.checkbox("Use Feature Selection")
                if use_feature_selection:
                    k_best_features = st.slider("Number of Best Features", 5, 50, 10)

+
with col4:
|
78 |
+
use_pca = st.checkbox("Use PCA")
|
79 |
+
if use_pca:
|
80 |
+
n_components = st.slider("PCA Components (%)", 1, 100, 95) / 100.0
|
81 |
+
|
82 |
+
add_cyber_features = st.checkbox("Add Cybersecurity Features")
|
83 |
+
|
84 |
+
feature_cols = st.multiselect(
|
85 |
+
"Select Features",
|
86 |
+
get_feature_names(df),
|
87 |
+
default=get_feature_names(df)
|
88 |
+
)
|
89 |
+
target_col = st.selectbox(
|
90 |
+
"Select Target Column",
|
91 |
+
df.columns.tolist()
|
92 |
+
)
|
93 |
+
|
94 |
+
# Create feature engineering config
|
95 |
+
feature_engineering_config = {
|
96 |
+
'use_polynomial': use_polynomial,
|
97 |
+
'poly_degree': poly_degree if use_polynomial else None,
|
98 |
+
'use_feature_selection': use_feature_selection,
|
99 |
+
'k_best_features': k_best_features if use_feature_selection else None,
|
100 |
+
'use_pca': use_pca,
|
101 |
+
'n_components': n_components if use_pca else None,
|
102 |
+
'add_cyber_features': add_cyber_features
|
103 |
+
}
|
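            # Assumed contract for DataProcessor.process_data (defined in
            # data_processing.py, which is not part of this commit):
            #   process_data(df, feature_cols, target_col, handling_strategy,
            #                scaling_method, feature_engineering_config)
            #   -> X_train, X_test, y_train, y_test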

            # Model Configuration Section
            st.header("3. Model Configuration")
            col5, col6 = st.columns(2)

            with col5:
                n_estimators = st.slider(
                    "Number of Trees",
                    min_value=10,
                    max_value=500,
                    value=100
                )
                max_depth = st.slider(
                    "Max Depth",
                    min_value=1,
                    max_value=50,
                    value=10
                )

            with col6:
                min_samples_split = st.slider(
                    "Min Samples Split",
                    min_value=2,
                    max_value=20,
                    value=2
                )
                min_samples_leaf = st.slider(
                    "Min Samples Leaf",
                    min_value=1,
                    max_value=10,
                    value=1
                )

            if st.button("Train Model"):
                with st.spinner("Processing data and training model..."):
                    # Process data with feature engineering
                    X_train, X_test, y_train, y_test = processor.process_data(
                        df,
                        feature_cols,
                        target_col,
                        handling_strategy,
                        scaling_method,
                        feature_engineering_config
                    )

                    # Train model
                    model, metrics = trainer.train_model(
                        X_train, X_test, y_train, y_test,
                        n_estimators=n_estimators,
                        max_depth=max_depth,
                        min_samples_split=min_samples_split,
                        min_samples_leaf=min_samples_leaf
                    )

                # Persist the results in session state: every widget click
                # reruns the script from the top, so a button nested directly
                # inside this block (the "Save Model" button below) would
                # otherwise never see a trained model.
                st.session_state['results'] = {
                    'model': model,
                    'metrics': metrics,
                    'X_train': X_train,
                    'X_test': X_test,
                    'y_test': y_test
                }

            # Results Section (rendered on every rerun once a model exists)
            if 'results' in st.session_state:
                results = st.session_state['results']
                model = results['model']
                metrics = results['metrics']
                X_train = results['X_train']
                X_test = results['X_test']
                y_test = results['y_test']

                st.header("4. Results and Visualizations")
                col7, col8 = st.columns(2)

                with col7:
                    st.subheader("Model Performance Metrics")
                    for metric, value in metrics.items():
                        st.metric(metric, f"{value:.4f}")

                    # Model export section
                    st.subheader("Export Model")
                    model_name = st.text_input("Model Name (optional)")
                    if st.button("Save Model"):
                        try:
                            # Save model and metadata
                            preprocessing_params = {
                                'feature_engineering_config': feature_engineering_config,
                                'handling_strategy': handling_strategy,
                                'scaling_method': scaling_method
                            }

                            model_path, metadata_path = save_model(
                                model,
                                feature_cols,
                                preprocessing_params,
                                metrics,
                                model_name
                            )

                            st.success(f"Model saved successfully! Files:\n- {model_path}\n- {metadata_path}")
                        except Exception as e:
                            st.error(f"Error saving model: {str(e)}")

                with col8:
                    if not use_pca:  # PCA components have no named features
                        st.subheader("Feature Importance")
                        # Polynomial expansion and feature selection both change
                        # the column count, so fall back to generic names whenever
                        # the selected columns no longer match the matrix.
                        if len(feature_cols) == X_train.shape[1]:
                            importance_names = feature_cols
                        else:
                            importance_names = [f"Feature_{i}" for i in range(X_train.shape[1])]
                        fig_importance = visualizer.plot_feature_importance(
                            model,
                            importance_names
                        )
                        st.pyplot(fig_importance)

                    # Confusion Matrix
                    st.subheader("Confusion Matrix")
                    fig_cm = visualizer.plot_confusion_matrix(
                        y_test,
                        model.predict(X_test)
                    )
                    st.pyplot(fig_cm)

                    # ROC Curve
                    st.subheader("ROC Curve")
                    fig_roc = visualizer.plot_roc_curve(
                        model,
                        X_test,
                        y_test
                    )
                    st.pyplot(fig_roc)

        except Exception as e:
            st.error(f"Error: {str(e)}")

    else:
        st.info("Please upload a dataset to begin.")

    # Model Management Section
    st.header("5. Saved Models")
    try:
        saved_models = list_saved_models()
        if saved_models:
            for model_info in saved_models:
                with st.expander(f"Model: {model_info['name']}"):
                    st.write(f"Type: {model_info['type']}")
                    st.write(f"Created: {model_info['created_at']}")
                    st.write("Performance Metrics:")
                    for metric, value in model_info['metrics'].items():
                        st.metric(metric, f"{value:.4f}")
        else:
            st.info("No saved models found.")
    except Exception as e:
        st.error(f"Error loading saved models: {str(e)}")

if __name__ == "__main__":
    main()
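app.py imports four sibling modules that are not part of this commit (data_processing.py, model_training.py, visualizations.py, utils.py). As a reading aid, here is a minimal sketch of the two utils helpers the upload flow relies on, inferred only from their call sites above; the actual utils.py shipped with the app may differ:

import pandas as pd

def load_data(uploaded_file):
    # Hypothetical implementation: the sidebar uploader accepts CSV or
    # JSON, so dispatch on the uploaded file's name.
    if uploaded_file.name.endswith('.json'):
        return pd.read_json(uploaded_file)
    return pd.read_csv(uploaded_file)

def get_feature_names(df):
    # Hypothetical implementation: offer the numeric columns as
    # candidate features.
    return df.select_dtypes(include='number').columns.tolist()

With the dependencies installed (streamlit and pandas at minimum; the training and plotting modules presumably pull in scikit-learn and matplotlib), the app starts with:

streamlit run app.py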