miracle-ema committed on
Commit
713c79c
·
verified ·
1 Parent(s): f6dde48

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +32 -0
  2. app.py +286 -0
  3. requirements.txt +13 -0
Dockerfile ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
## Use the official Python 3.9 image
FROM python:3.9

## Set the working directory to /code for the dependency-install step
WORKDIR /code

## Copy only requirements.txt first so the dependency layer can be cached
## independently of application source changes
COPY ./requirements.txt /code/requirements.txt

## Install the packages listed in requirements.txt
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Set up a new user named "user".
# -m creates /home/user (without it the home directory does not exist and the
# app cannot create its "uploads" directory at startup); -u 1000 pins the UID
# to the one Hugging Face Spaces runs Docker containers as.
RUN useradd -m -u 1000 user
# Switch to the "user" user
USER user

# Set HOME to the user's home directory and expose user-level scripts on PATH
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Set the working directory to the application directory under the user's home
WORKDIR $HOME/app

# Copy the current directory contents into the container at $HOME/app,
# setting the user as the owner to avoid permission issues
COPY --chown=user . $HOME/app

## Start the FastAPI app on port 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import HTMLResponse
4
+ from fastapi.staticfiles import StaticFiles
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ import os
9
+ import logging
10
+ from huggingface_hub import InferenceClient
11
+ from dotenv import load_dotenv
12
+ import hashlib
13
+ import ast
14
+ import re
15
+
16
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load variables from a local .env file (if present) into the environment.
load_dotenv()

app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive CORS policy; restrict allow_origins to the real frontend origin
# before exposing this service beyond a demo.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Everything under ./static is served at /static (frontend and generated plots).
app.mount("/static", StaticFiles(directory="static"), name="static")

# The Hugging Face access token is read from HF_TOKEN; fail fast at import time
# when it is missing so the error names the variable that is actually read.
API_TOKEN = os.getenv("HF_TOKEN")
if not API_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set.")

MODEL_NAME = "bigcode/starcoder"
client = InferenceClient(model=MODEL_NAME, token=API_TOKEN)

# Directory where uploaded Excel files are stored.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Generated plots must live inside the mounted "static" directory, because the
# API returns URLs of the form /static/images/<file>. The previous value
# ("../static/images") pointed outside the mounted directory, so the returned
# plot URLs could not be served.
IMAGES_DIR = os.path.join("static", "images")
os.makedirs(IMAGES_DIR, exist_ok=True)
45
+
46
# Ordered (plot type, pattern) pairs; the first pattern that matches wins.
# Whole-word matching avoids false positives from column names or ordinary
# words that merely contain a keyword (e.g. "baseline", "bar_code").
_PLOT_TYPE_PATTERNS = (
    ("bar", re.compile(r"\bbar(?:s|plot|chart|graph)?\b")),
    ("histogram", re.compile(r"\b(?:histograms?|distributions?)\b")),
    ("line", re.compile(r"\bline(?:s|plot|chart|graph)?\b")),
)


def detect_plot_type(prompt):
    """Detect the requested plot type from the prompt.

    Args:
        prompt: Free-text user request. ``None`` or an empty string is
            treated as a request with no keywords.

    Returns:
        One of ``"bar"``, ``"histogram"``, ``"line"`` or ``"scatter"``.
        Keywords are checked case-insensitively in that priority order, and
        ``"scatter"`` is the fallback when none of them is present.
    """
    text = (prompt or "").lower()
    for plot_type, pattern in _PLOT_TYPE_PATTERNS:
        if pattern.search(text):
            return plot_type
    return "scatter"
57
+
58
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded Excel workbook in ``UPLOAD_DIR``.

    Args:
        file: The multipart upload; its name must end in ``.xlsx``
            (case-insensitive).

    Returns:
        ``{"filename": <stored name>}`` — the name to pass to
        ``/generate-visualization/``.

    Raises:
        HTTPException: 400 when the upload has no name or is not an
            ``.xlsx`` file.
    """
    # The filename is client-controlled: keep only its final path component
    # (for both "/" and "\" separators) so a name such as "../../app.py.xlsx"
    # cannot write outside UPLOAD_DIR.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if not safe_name.lower().endswith(".xlsx"):
        raise HTTPException(status_code=400, detail="File must be an Excel file (.xlsx)")

    file_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(file_path, "wb") as buffer:
        buffer.write(await file.read())

    logger.info(f"File uploaded: {safe_name}")
    return {"filename": safe_name}
69
+
70
# Lines of model output starting with any of these prefixes are dropped before
# execution. "from " is included so that "from x import y" is filtered just
# like "import x".
_BLOCKED_LINE_PREFIXES = ('#', 'def', 'class', 'import', 'from ', 'df =')
# Lines containing any of these substrings are dropped before execution.
_BLOCKED_SUBSTRINGS = ("pd.read_csv", "pd.read_excel", "http", "raise", "print", "plt.show")
# A line that starts with a number followed by whitespace (e.g. a printed
# DataFrame row).
_TABLE_ROW_RE = re.compile(r'^\s*\d+\s+.*$')
# A bracketed "[... rows ... columns]" line such as a printed DataFrame footer.
_SHAPE_FOOTER_RE = re.compile(r'^\s*\[.*rows.*columns\]\s*$')
# A Markdown code fence, including an optional "python"/"py" language tag so
# the tag is not left behind as a bare statement.
_CODE_FENCE_RE = re.compile(r'```(?:python\b|py\b)?', re.IGNORECASE)


def _build_visualization_prompt(filename, columns, prompt, plot_type):
    """Return the few-shot instruction text sent to the code-generation model."""
    # NOTE(review): leading whitespace inside this template is reproduced as it
    # appears in the flattened diff view; confirm against the repository copy.
    return f"""
You are a Python code generator specializing in data visualization. The DataFrame 'df' is already loaded from an Excel file '{filename}' with columns {columns}.

The user requests: '{prompt}'.

Instructions:
- Generate Python code to create a {plot_type} plot based on the user's natural language prompt using the pre-loaded DataFrame 'df', pandas (pd), matplotlib.pyplot (plt), and seaborn (sns).
- Include the following imports at the top of the code, preceded by a comment:
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
- Include a line to read the DataFrame, preceded by a comment (even though it will be removed during execution):
# load data
df = pd.read_excel('{filename}')
- Add xlabel and ylabel using human-readable forms inferred from the prompt (e.g., 'Petal Length' if the prompt mentions "petal length").
- Add a title using plt.title(). Format based on plot type and prompt context:
- Scatter: "<X> vs <Y>" or "<X> vs <Y> by <Hue>" if "colored by" is present
- Bar: "<Y> by <X>" or "Average <Y> by <X>" if averages are requested
- Histogram: "Distribution of <X>"
- Line: "<Y> by <X>" or "<X> vs <Y>"
- For averages, use df.groupby().mean() if "average" or "mean" is in the prompt.
- Plot type specifics:
- Scatter: Use sns.scatterplot with hue=<column> if "colored by" is present, else plt.scatter
- Bar: Use sns.barplot; apply groupby if averages are requested
- Histogram: Use sns.histplot
- Line: Use sns.lineplot
- Automatically infer column names from the prompt and match them to the exact DataFrame columns ({columns}) based on context. Use the exact column names as they appear in the DataFrame.
- Include plt.show() at the end (will be removed during execution).
- Output only the Python code as valid Python.

Examples:
- For "Create a scatter plot of column1 vs column2":
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# load data
df = pd.read_excel('{filename}')

sns.scatterplot(x='column1', y='column2', data=df)
plt.xlabel('Column 1')
plt.ylabel('Column 2')
plt.title('Column 1 vs Column 2')

plt.show()

- For "Create a scatter plot of column1 vs column2 colored by column3":
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# load data
df = pd.read_excel('{filename}')

sns.scatterplot(x='column1', y='column2', hue='column3', data=df)
plt.xlabel('Column 1')
plt.ylabel('Column 2')
plt.title('Column 1 vs Column 2 by Column3')

plt.show()

- For "Create a bar chart of column1 by column2":
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# load data
df = pd.read_excel('{filename}')

sns.barplot(x='column2', y='column1', data=df)
plt.xlabel('Column 2')
plt.ylabel('Column 1')
plt.title('Column 1 by Column 2')

plt.show()

- For "Create a bar chart of average column1 by column2":
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# load data
df = pd.read_excel('{filename}')

sns.barplot(x='column2', y='column1', data=df.groupby('column2').mean().reset_index())
plt.xlabel('Column 2')
plt.ylabel('Average Column 1')
plt.title('Average Column 1 by Column 2')

plt.show()

- For "Create a histogram of column1":
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# load data
df = pd.read_excel('{filename}')

sns.histplot(df['column1'])
plt.xlabel('Column 1')
plt.ylabel('Frequency')
plt.title('Distribution of Column 1')

plt.show()

- For "Create a line chart of column1 by column2":
# import libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# load data
df = pd.read_excel('{filename}')

sns.lineplot(x='column2', y='column1', data=df)
plt.xlabel('Column 2')
plt.ylabel('Column 1')
plt.title('Column 1 by Column 2')

plt.show()

Generate the code for the user's request now. Output only the Python code, nothing else:
"""


def _clean_generated_code(raw_code):
    """Reduce raw model output to a flat list of plotting statements.

    Code fences and triple quotes are removed; blank lines, comments,
    definitions, imports, data-loading lines, and lines that look like printed
    table output are dropped. Every kept line is stripped of indentation.
    """
    text = _CODE_FENCE_RE.sub('', raw_code.strip())
    text = text.replace('"""', '').replace("'''", '')

    kept = []
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped:
            continue
        if stripped.startswith(_BLOCKED_LINE_PREFIXES):
            continue
        if any(kw in line for kw in _BLOCKED_SUBSTRINGS):
            continue
        if _TABLE_ROW_RE.match(line) or _SHAPE_FOOTER_RE.match(line):
            continue
        kept.append(stripped)
    return "\n".join(kept).strip()


@app.post("/generate-visualization/")
async def generate_visualization(prompt: str = Form(...), filename: str = Form(...)):
    """Generate a plot for an uploaded workbook from a natural-language prompt.

    The workbook is loaded into a DataFrame, the model is asked for plotting
    code, the code is filtered and syntax-checked, executed against the
    DataFrame, and the resulting figure is saved under ``IMAGES_DIR``.

    Args:
        prompt: The user's description of the desired plot.
        filename: Name of a previously uploaded file inside ``UPLOAD_DIR``.

    Returns:
        ``{"plot_url": ..., "generated_code": ...}`` where ``generated_code``
        is the unfiltered model output.

    Raises:
        HTTPException: 404 if the file is missing, 400 if it cannot be read or
            is empty, 500 for model, filtering, syntax or execution failures.
    """
    # The filename is client-controlled: keep only the final path component so
    # it cannot address files outside UPLOAD_DIR.
    filename = os.path.basename(filename.replace("\\", "/"))
    file_path = os.path.join(UPLOAD_DIR, filename)

    if not filename or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found on server.")

    try:
        df = pd.read_excel(file_path)
        if df.empty:
            raise ValueError("Excel file is empty.")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading Excel file: {str(e)}") from e

    plot_type = detect_plot_type(prompt)
    # Column labels are not necessarily strings (e.g. numeric headers), so
    # convert them before joining.
    columns = ', '.join(map(str, df.columns))
    input_text = _build_visualization_prompt(filename, columns, prompt, plot_type)

    try:
        raw_generated_code = client.text_generation(input_text, max_new_tokens=400)
        logger.info(f"Raw generated code: '{raw_generated_code}'")
    except Exception as e:
        logger.error(f"Error querying model: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error querying model: {str(e)}") from e

    if not raw_generated_code.strip():
        logger.error("No code generated by the AI model.")
        raise HTTPException(status_code=500, detail="No code generated by the AI model.")

    cleaned_code = _clean_generated_code(raw_generated_code)
    logger.info(f"Cleaned code: '{cleaned_code}'")

    if not cleaned_code:
        logger.error("Cleaned code is empty after filtering.")
        raise HTTPException(status_code=500, detail="Generated code is empty or contains only disallowed content")

    try:
        ast.parse(cleaned_code)
    except SyntaxError as e:
        logger.error(f"Syntax error in cleaned code: '{cleaned_code}' Exception: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Syntax error in generated code: {str(e)}") from e

    # Short, deterministic file name per (file, prompt) pair; md5 is used only
    # for naming, not for security.
    plot_hash = hashlib.md5(f"{filename}_{prompt}".encode()).hexdigest()[:8]
    plot_filename = f"plot_{plot_hash}.png"
    plot_path = os.path.join(IMAGES_DIR, plot_filename)

    # SECURITY: the code executed below is model output steered by a
    # user-supplied prompt. The line filter above is NOT a sandbox; exec() on
    # this text can run arbitrary Python in the server process. Run this
    # service only in an isolated, disposable environment, or replace exec()
    # with a restricted plotting DSL.
    exec_globals = {"pd": pd, "plt": plt, "sns": sns, "df": df}
    plt.close('all')
    plt.figure(figsize=(8, 6))
    try:
        exec(cleaned_code, exec_globals)
        # Inspect the *current* figure: generated code may open its own figure,
        # and plt.savefig() below saves the current one.
        if not plt.gcf().get_axes():
            raise ValueError("Generated code produced an empty plot")
        plt.savefig(plot_path, bbox_inches="tight")
        logger.info(f"Plot saved to {plot_path}")
    except Exception as e:
        logger.error(f"Error executing cleaned code: '{cleaned_code}' Exception: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error executing code: {str(e)}") from e
    finally:
        # Always release figure state so it cannot leak into the next request.
        plt.close('all')

    if not os.path.exists(plot_path):
        raise HTTPException(status_code=500, detail="Plot file was not created.")

    # The timestamp query parameter defeats browser caching when the same
    # (file, prompt) pair is regenerated under the same file name.
    plot_url = f"/static/images/{plot_filename}?t={int(pd.Timestamp.now().timestamp())}"
    return {"plot_url": plot_url, "generated_code": raw_generated_code}
281
+
282
@app.get("/")
async def serve_frontend():
    """Serve the single-page frontend from ``static/index.html``.

    Returns:
        An ``HTMLResponse`` containing the file's contents.

    Raises:
        HTTPException: 404 when ``static/index.html`` does not exist.
    """
    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # container's locale.
        with open("static/index.html", "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail="Frontend not found.") from e
286
+
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.0
2
+ uvicorn==0.30.6
3
+ pandas==2.2.2
4
+ matplotlib==3.9.4
5
+ seaborn==0.13.2
6
+ python-multipart==0.0.9
7
+ transformers==4.45.2
8
+ torch==2.4.1
9
+ openpyxl==3.1.5
10
+ python-dotenv==1.0.1
11
+ huggingface_hub==0.23.4
12
+
13
+