yangyang158 commited on
Commit
8334f8b
·
1 Parent(s): 1dd235a

feat: Integrate UI control panel and fix CORS issue

Browse files
Files changed (4) hide show
  1. __pycache__/app.cpython-312.pyc +0 -0
  2. app.py +23 -2
  3. index.html +215 -0
  4. requirements.txt +1 -0
__pycache__/app.cpython-312.pyc CHANGED
Binary files a/__pycache__/app.cpython-312.pyc and b/__pycache__/app.cpython-312.pyc differ
 
app.py CHANGED
@@ -2,7 +2,8 @@ import os
2
  import sys
3
  import logging
4
  from functools import wraps
5
- from flask import Flask, request, jsonify
 
6
  import torch
7
  import pandas as pd
8
  from huggingface_hub import hf_hub_download
@@ -19,6 +20,7 @@ except ImportError as e:
19
 
20
  # --- Globals ---
21
  app = Flask(__name__)
 
22
  predictor = None
23
  model_name_global = "kronos-base" # Use key now
24
  API_KEY = os.environ.get("KRONOS_API_KEY")
@@ -87,6 +89,11 @@ def require_api_key(f):
87
 
88
  # --- API Endpoints ---
89
 
 
 
 
 
 
90
  @app.route('/api/load-model', methods=['POST'])
91
  @require_api_key
92
  def load_model_endpoint():
@@ -238,7 +245,15 @@ def predict():
238
  )
239
 
240
  # Format results for JSON response
241
- prediction_results = pred_df.to_dict(orient='records')
 
 
 
 
 
 
 
 
242
 
243
  return jsonify({
244
  'success': True,
@@ -249,3 +264,9 @@ def predict():
249
  except Exception as e:
250
  logging.error(f"Prediction failed: {e}")
251
  return jsonify({'error': f'An error occurred during prediction: {str(e)}'}), 500
 
 
 
 
 
 
 
2
  import sys
3
  import logging
4
  from functools import wraps
5
+ from flask import Flask, request, jsonify, send_from_directory
6
+ from flask_cors import CORS
7
  import torch
8
  import pandas as pd
9
  from huggingface_hub import hf_hub_download
 
20
 
21
  # --- Globals ---
22
  app = Flask(__name__)
23
+ CORS(app) # Enable CORS for all routes
24
  predictor = None
25
  model_name_global = "kronos-base" # Use key now
26
  API_KEY = os.environ.get("KRONOS_API_KEY")
 
89
 
90
  # --- API Endpoints ---
91
 
92
+ @app.route('/')
93
+ def index():
94
+ """Serves the index.html file for the visualizer."""
95
+ return send_from_directory('.', 'index.html')
96
+
97
  @app.route('/api/load-model', methods=['POST'])
98
  @require_api_key
99
  def load_model_endpoint():
 
245
  )
246
 
247
  # Format results for JSON response
248
+ # --- Format results to match input format ---
249
+ pred_df_reset = pred_df.reset_index()
250
+ # Convert timestamp to Unix milliseconds integer
251
+ pred_df_reset['timestamp'] = (pred_df_reset['timestamp'].astype('int64') / 10**6).astype('int64')
252
+ # Reorder columns to match the desired output format: [timestamp, open, high, low, close, volume]
253
+ output_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
254
+ pred_df_formatted = pred_df_reset[output_columns]
255
+ # Convert to list of lists
256
+ prediction_results = pred_df_formatted.values.tolist()
257
 
258
  return jsonify({
259
  'success': True,
 
264
  except Exception as e:
265
  logging.error(f"Prediction failed: {e}")
266
  return jsonify({'error': f'An error occurred during prediction: {str(e)}'}), 500
267
+
268
+
269
+ if __name__ == '__main__':
270
+ # This block is for local debugging purposes.
271
+ # The production server will use a WSGI server like Gunicorn.
272
+ app.run(host='0.0.0.0', port=7860, debug=True)
index.html ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Kronos API Prediction Visualizer</title>
7
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/echarts.min.js"></script>
8
+ <style>
9
+ body { font-family: sans-serif; margin: 2em; }
10
+ .container { max-width: 1200px; margin: auto; }
11
+ .form-group { margin-bottom: 1em; }
12
+ label { display: block; margin-bottom: 0.5em; }
13
+ input, textarea, select { width: 100%; padding: 0.5em; box-sizing: border-box; }
14
+ textarea { min-height: 150px; }
15
+ button { padding: 0.7em 1.5em; cursor: pointer; }
16
+ #chart { width: 100%; height: 600px; margin-top: 2em; border: 1px solid #ccc; }
17
+ .error { color: red; }
18
+ .success { color: green; }
19
+ .status { margin-top: 1em; font-weight: bold; }
20
+ .info { background-color: #f0f0f0; border-left: 4px solid #007bff; padding: 1em; margin-bottom: 1em; }
21
+ hr { margin: 2em 0; }
22
+ </style>
23
+ </head>
24
+ <body>
25
+ <div class="container">
26
+ <h1>Kronos API Control Panel</h1>
27
+
28
+ <div class="info">
29
+ <p>
30
+ For local testing, first run <code>app.py</code> in VS Code (press F5). The API endpoints will be available at <code>http://localhost:7860</code>.
31
+ </p>
32
+ </div>
33
+
34
+ <!-- Section for Model Loading -->
35
+ <section id="model-loader">
36
+ <h2>1. Load Model</h2>
37
+ <div class="form-group">
38
+ <label for="model-select">Available Models:</label>
39
+ <select id="model-select"></select>
40
+ </div>
41
+ <button id="load-model-btn">Load Selected Model</button>
42
+ <div id="model-status" class="status"></div>
43
+ </section>
44
+
45
+ <hr>
46
+
47
+ <!-- Section for Prediction -->
48
+ <section id="predictor">
49
+ <h2>2. Get Prediction</h2>
50
+ <div class="form-group">
51
+ <label for="api-key">API Key (Bearer Token, if required):</label>
52
+ <input type="password" id="api-key" placeholder="Enter your API Key">
53
+ </div>
54
+ <div class="form-group">
55
+ <label for="k-lines">K-line Data (JSON Array of Arrays):</label>
56
+ <textarea id="k-lines" placeholder="Paste your k-line data here..."></textarea>
57
+ </div>
58
+ <div class="form-group">
59
+ <label for="pred-len">Prediction Length:</label>
60
+ <input type="number" id="pred-len" value="120">
61
+ </div>
62
+ <button id="predict-btn">Get Prediction & Visualize</button>
63
+ </section>
64
+
65
+ <div id="chart"></div>
66
+ <div id="error-message" class="status error"></div>
67
+ </div>
68
+
69
+ <script>
70
+ // --- Global DOM Elements ---
71
+ const modelSelect = document.getElementById('model-select');
72
+ const loadModelBtn = document.getElementById('load-model-btn');
73
+ const modelStatusDiv = document.getElementById('model-status');
74
+ const predictBtn = document.getElementById('predict-btn');
75
+ const apiKeyInput = document.getElementById('api-key');
76
+ const kLinesTextarea = document.getElementById('k-lines');
77
+ const predLenInput = document.getElementById('pred-len');
78
+ const chartDom = document.getElementById('chart');
79
+ const errorDiv = document.getElementById('error-message');
80
+
81
+ // --- API Base URLs ---
82
+ // Use relative paths for deployed environment, detect local for testing.
83
+ const isLocal = window.location.hostname === 'localhost' || window.location.hostname === '127.0.0.1';
84
+ const apiBaseUrl = isLocal ? 'http://localhost:7860' : '';
85
+
86
+ // --- Helper Functions ---
87
+ async function apiFetch(endpoint, options) {
88
+ const apiKey = apiKeyInput.value;
89
+ const headers = {
90
+ 'Content-Type': 'application/json',
91
+ ...options.headers,
92
+ };
93
+ if (apiKey) {
94
+ headers['Authorization'] = `Bearer ${apiKey}`;
95
+ }
96
+
97
+ const response = await fetch(apiBaseUrl + endpoint, { ...options, headers });
98
+
99
+ if (!response.ok) {
100
+ const errorData = await response.json();
101
+ throw new Error(`API Error (${response.status}): ${errorData.error || 'Unknown error'}`);
102
+ }
103
+ return response.json();
104
+ }
105
+
106
+ // --- Model Loading Logic ---
107
+ async function populateModels() {
108
+ try {
109
+ const models = await apiFetch('/api/available-models', { method: 'GET' });
110
+ modelSelect.innerHTML = ''; // Clear existing options
111
+ for (const key in models) {
112
+ const option = document.createElement('option');
113
+ option.value = key;
114
+ option.textContent = `${models[key].name} (${models[key].params}) - ${models[key].description}`;
115
+ modelSelect.appendChild(option);
116
+ }
117
+ } catch (error) {
118
+ modelStatusDiv.className = 'status error';
119
+ modelStatusDiv.textContent = `Failed to fetch models: ${error.message}`;
120
+ }
121
+ }
122
+
123
+ loadModelBtn.addEventListener('click', async () => {
124
+ const modelKey = modelSelect.value;
125
+ modelStatusDiv.className = 'status';
126
+ modelStatusDiv.textContent = `Loading model '${modelKey}'...`;
127
+ try {
128
+ const result = await apiFetch('/api/load-model', {
129
+ method: 'POST',
130
+ body: JSON.stringify({ model_key: modelKey })
131
+ });
132
+ modelStatusDiv.className = 'status success';
133
+ modelStatusDiv.textContent = result.status;
134
+ } catch (error) {
135
+ modelStatusDiv.className = 'status error';
136
+ modelStatusDiv.textContent = error.message;
137
+ }
138
+ });
139
+
140
+ // --- Prediction Logic ---
141
+ predictBtn.addEventListener('click', async () => {
142
+ const kLinesText = kLinesTextarea.value;
143
+ const predLen = parseInt(predLenInput.value, 10);
144
+
145
+ errorDiv.textContent = '';
146
+ const myChart = echarts.init(chartDom);
147
+ myChart.showLoading();
148
+
149
+ if (!kLinesText) {
150
+ errorDiv.textContent = 'K-line data cannot be empty.';
151
+ myChart.hideLoading();
152
+ return;
153
+ }
154
+
155
+ let kLines;
156
+ try {
157
+ kLines = JSON.parse(kLinesText);
158
+ } catch (e) {
159
+ errorDiv.textContent = 'Invalid JSON in K-line data. Please check the format.';
160
+ myChart.hideLoading();
161
+ return;
162
+ }
163
+
164
+ const payload = {
165
+ k_lines: kLines,
166
+ prediction_params: { pred_len: predLen }
167
+ };
168
+
169
+ try {
170
+ const result = await apiFetch('/api/predict', {
171
+ method: 'POST',
172
+ body: JSON.stringify(payload)
173
+ });
174
+
175
+ const historicalData = kLines.map(item => [
176
+ item[0], parseFloat(item[1]), parseFloat(item[4]), parseFloat(item[3]), parseFloat(item[2])
177
+ ]);
178
+
179
+ const predictionData = result.prediction_results.map(item => [
180
+ item[0], parseFloat(item[1]), parseFloat(item[4]), parseFloat(item[3]), parseFloat(item[2])
181
+ ]);
182
+
183
+ const allTimestamps = [...historicalData.map(d => d[0]), ...predictionData.map(d => d[0])];
184
+
185
+ const option = {
186
+ tooltip: { trigger: 'axis', axisPointer: { type: 'cross' } },
187
+ legend: { data: ['Historical', 'Prediction'] },
188
+ grid: { left: '10%', right: '10%', bottom: '15%' },
189
+ xAxis: { type: 'time', min: allTimestamps[0], max: allTimestamps[allTimestamps.length - 1] },
190
+ yAxis: { scale: true, splitArea: { show: true } },
191
+ dataZoom: [
192
+ { type: 'inside', start: 50, end: 100 },
193
+ { show: true, type: 'slider', top: '90%', start: 50, end: 100 }
194
+ ],
195
+ series: [
196
+ { name: 'Historical', type: 'candlestick', data: historicalData, itemStyle: { color: '#00da3c', color0: '#ec0000', borderColor: '#008F28', borderColor0: '#8A0000' } },
197
+ { name: 'Prediction', type: 'candlestick', data: predictionData, itemStyle: { color: '#4287f5', color0: '#f54242', borderColor: '#285199', borderColor0: '#992828' } }
198
+ ]
199
+ };
200
+
201
+ myChart.hideLoading();
202
+ myChart.setOption(option);
203
+
204
+ } catch (error) {
205
+ myChart.hideLoading();
206
+ errorDiv.textContent = error.message;
207
+ console.error('Fetch error:', error);
208
+ }
209
+ });
210
+
211
+ // --- Initial Load ---
212
+ document.addEventListener('DOMContentLoaded', populateModels);
213
+ </script>
214
+ </body>
215
+ </html>
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  flask
 
2
  pandas
3
  huggingface_hub
4
  transformers
 
1
  flask
2
+ flask-cors
3
  pandas
4
  huggingface_hub
5
  transformers