"""Smart Code Assistant: a LangGraph pipeline that retrieves similar HumanEval
solutions from a ChromaDB vector store and calls Groq-hosted Llama models to
generate, explain, or test Python code behind a Gradio UI."""

import datetime
import hashlib
import json
import os
import subprocess
import sys
import tempfile
from typing import TypedDict, Optional, List

import chromadb
import gradio as gr
import pandas as pd
from chromadb.config import Settings
from groq import Groq
from langgraph.graph import StateGraph, END
from sentence_transformers import SentenceTransformer

# Groq API key, read from the environment variable `api_key_coder`.
api_key_coder = os.environ.get("api_key_coder")

class CodeAssistantState(TypedDict):
    """Shared state passed between graph nodes. Every key a node writes must be
    declared here, otherwise LangGraph will not carry it through the graph."""
    user_input: str
    similar_examples: Optional[List[str]]
    generated_code: Optional[str]
    generated_tests: Optional[str]
    error: Optional[str]
    task_type: Optional[str]
    evaluation_result: Optional[str]
    metadata: Optional[dict]
    test_results: Optional[dict]
    test_error: Optional[str]

# HumanEval: task prompts and canonical solutions, used as the RAG corpus.
df = pd.read_parquet("hf://datasets/openai/openai_humaneval/openai_humaneval/test-00000-of-00001.parquet")
extracted_data = df[['task_id', 'prompt', 'canonical_solution']]

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
groq_client = Groq(api_key=api_key_coder)

# On-disk Chroma store; chromadb.PersistentClient is the current API for
# persistence. Cosine distance is used for the HNSW index.
client = chromadb.PersistentClient(
    path="rag_db",
    settings=Settings(anonymized_telemetry=False)
)
collection = client.get_or_create_collection(
    name="code_examples",
    metadata={"hnsw:space": "cosine"}
)

def initialize_db(state: CodeAssistantState):
    """Seed the vector store with HumanEval prompt/solution pairs."""
    try:
        # Skip re-embedding when the collection is already populated; the graph
        # runs this node on every request.
        if collection.count() > 0:
            return state
        for _, row in extracted_data.iterrows():
            embedding = embedding_model.encode([row['prompt'].strip()])[0]
            doc_id = hashlib.md5(row['prompt'].encode()).hexdigest()
            collection.add(
                documents=[row['canonical_solution'].strip()],
                metadatas=[{"prompt": row['prompt'], "type": "code_example"}],
                ids=[doc_id],
                embeddings=[embedding]
            )
        return state
    except Exception as e:
        state["error"] = f"DB initialization failed: {str(e)}"
        return state

def retrieve_examples(state: CodeAssistantState):
    """Fetch the two most similar stored solutions for the user request."""
    try:
        embedding = embedding_model.encode([state["user_input"]])[0]
        results = collection.query(
            query_embeddings=[embedding],
            n_results=2
        )
        state["similar_examples"] = results['documents'][0] if results['documents'] else None
        return state
    except Exception as e:
        state["error"] = f"Retrieval failed: {str(e)}"
        return state

def classify_task_llm(state: CodeAssistantState) -> CodeAssistantState:
    """Classify the request as generate/explain/test via a JSON-mode LLM call."""
    if not isinstance(state, dict):
        raise ValueError("State must be a dictionary")

    if "user_input" not in state or not state["user_input"].strip():
        state["error"] = "No user input provided for classification"
        state["task_type"] = "generate"
        return state

    try:
        prompt = f"""You are a helpful code assistant. Classify the user request as one of the following tasks:
- "generate": if the user wants to write or generate code
- "explain": if the user wants to understand what a code snippet does
- "test": if the user wants to test existing code
Return ONLY a JSON object in the format: {{"task": "...", "user_input": "..."}} with no explanation.
User request: {state["user_input"]}
"""
        completion = groq_client.chat.completions.create(
            # Unified on the same model the other nodes use.
            model="llama-3.3-70b-versatile",
            messages=[
                {"role": "system", "content": "Classify code-related user input. Respond with ONLY JSON."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            max_tokens=200,
            response_format={"type": "json_object"}
        )

        content = completion.choices[0].message.content.strip()

        try:
            result = json.loads(content)
            if not isinstance(result, dict):
                raise ValueError("Response is not a JSON object")
        except (json.JSONDecodeError, ValueError) as e:
            state["error"] = f"Invalid response format from LLM: {str(e)}. Content: {content}"
            state["task_type"] = "generate"
            return state

        # Fall back to "generate" on any unexpected label.
        task_type = result.get("task", "").lower()
        if task_type not in ["generate", "explain", "test"]:
            state["error"] = f"Invalid task type received: {task_type}"
            task_type = "generate"

        state["task_type"] = task_type
        state["user_input"] = result.get("user_input", state["user_input"])
        return state

    except Exception as e:
        state["error"] = f"LLM-based classification failed: {str(e)}"
        state["task_type"] = "generate"
        return state

def test_code(state: CodeAssistantState) -> CodeAssistantState:
    """Generate unit tests for the provided code, then run them in a subprocess."""
    if not isinstance(state, dict):
        raise ValueError("State must be a dictionary")

    if "user_input" not in state or not state["user_input"].strip():
        state["error"] = "Please provide the code you want to test"
        return state

    try:
        messages = [
            {"role": "system", "content": """You are a Python testing expert. Generate unit tests for the provided code.
Return the test code in the following format:
```python
# Test code here
```"""},
            {"role": "user", "content": f"Generate comprehensive unit tests for this Python code:\n\n{state['user_input']}"}
        ]

        completion = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=messages,
            temperature=0.5,
            max_tokens=2048,
        )

        # Strip a surrounding markdown fence, if the model added one.
        test_code = completion.choices[0].message.content
        if test_code.startswith('```python'):
            test_code = test_code[9:-3] if test_code.endswith('```') else test_code[9:]
        elif test_code.startswith('```'):
            test_code = test_code[3:-3] if test_code.endswith('```') else test_code[3:]

        state["generated_tests"] = test_code.strip()
        state["metadata"] = {
            "model": "llama-3.3-70b-versatile",
            "timestamp": datetime.datetime.now().isoformat()
        }

        try:
            # Write the code under test and the generated tests into a single
            # file so the tests can see the definitions, then run it with a
            # timeout using the same interpreter as this process.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as test_file:
                test_file.write(state['user_input'] + "\n\n" + test_code)
                test_file_path = test_file.name

            try:
                result = subprocess.run(
                    [sys.executable, test_file_path],
                    capture_output=True,
                    text=True,
                    timeout=10
                )
                state["test_results"] = {
                    "returncode": result.returncode,
                    "stdout": result.stdout,
                    "stderr": result.stderr
                }
            finally:
                os.unlink(test_file_path)

        except Exception as e:
            state["test_error"] = f"Error executing tests: {str(e)}"

        print(f"\nGenerated Tests:\n{test_code.strip()}\n")
        if "test_results" in state:
            print(f"Test Execution Results:\n{state['test_results']['stdout']}")
            if state["test_results"]["stderr"]:
                print(f"Errors:\n{state['test_results']['stderr']}")

        return state

    except Exception as e:
        state["error"] = f"Error generating tests: {str(e)}"
        return state

def generate_code(state: CodeAssistantState) -> CodeAssistantState:
    """Generate code for the request, grounding the prompt in retrieved examples."""
    if not isinstance(state, dict):
        raise ValueError("State must be a dictionary")

    if "user_input" not in state or not state["user_input"].strip():
        state["error"] = "Please enter your code request"
        return state

    try:
        # Fold the retrieved HumanEval solutions into the prompt so the RAG
        # retrieval step actually informs generation.
        examples = state.get("similar_examples") or []
        context = ""
        if examples:
            context = "\n\nSimilar reference solutions:\n" + "\n---\n".join(examples)

        messages = [
            {"role": "system", "content": "You are a Python coding assistant. Return only clean, production-ready code."},
            {"role": "user", "content": state["user_input"].strip() + context}
        ]

        completion = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=messages,
            temperature=0.7,
            max_tokens=2048,
        )

        # Strip a surrounding markdown fence, if the model added one.
        code = completion.choices[0].message.content
        if code.startswith('```python'):
            code = code[9:-3] if code.endswith('```') else code[9:]
        elif code.startswith('```'):
            code = code[3:-3] if code.endswith('```') else code[3:]

        state["generated_code"] = code.strip()
        state["metadata"] = {
            "model": "llama-3.3-70b-versatile",
            "timestamp": datetime.datetime.now().isoformat()
        }

        print(f"\nGenerated Code:\n{code.strip()}\n")

        return state

    except Exception as e:
        state["error"] = f"Error generating code: {str(e)}"
        return state

def explain_code(state: CodeAssistantState) -> CodeAssistantState:
    """Explain what the submitted code does in plain language."""
    try:
        messages = [
            {"role": "system", "content": "You are a Python expert. Explain what the following code does in plain language."},
            {"role": "user", "content": state["user_input"].strip()}
        ]

        completion = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=messages,
            temperature=0.5,
            max_tokens=1024
        )

        explanation = completion.choices[0].message.content.strip()
        # The explanation is stored in generated_code so process_input can
        # display it through the shared output path.
        state["generated_code"] = explanation
        state["metadata"] = {
            "model": "llama-3.3-70b-versatile",
            "timestamp": datetime.datetime.now().isoformat()
        }

        print(f"Explanation:\n{explanation}")

        return state

    except Exception as e:
        state["error"] = f"Error explaining code: {str(e)}"
        return state

# Build the graph: seed DB -> retrieve examples -> classify -> route by task.
workflow = StateGraph(CodeAssistantState)

workflow.add_node("initialize_db", initialize_db)
workflow.add_node("retrieve_examples", retrieve_examples)
workflow.add_node("classify_task", classify_task_llm)
workflow.add_node("generate_code", generate_code)
workflow.add_node("explain_code", explain_code)
workflow.add_node("test_code", test_code)

workflow.set_entry_point("initialize_db")
workflow.add_edge("initialize_db", "retrieve_examples")
workflow.add_edge("retrieve_examples", "classify_task")

workflow.add_conditional_edges(
    "classify_task",
    lambda state: state["task_type"],
    {
        "generate": "generate_code",
        "explain": "explain_code",
        "test": "test_code"
    }
)

workflow.add_edge("generate_code", END)
workflow.add_edge("explain_code", END)
workflow.add_edge("test_code", END)

app_workflow = workflow.compile()

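# Optional smoke test of the compiled graph outside the UI -- a minimal sketch,
# assuming the Groq key and dataset are reachable; not required by the app:
#
#   result = app_workflow.invoke({"user_input": "Write a function to add two numbers",
#                                 "similar_examples": None, "generated_code": None,
#                                 "error": None, "task_type": None})
#   print(result["generated_code"])
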
def process_input(user_input: str):
    """Gradio callback: run the workflow and format the result for display."""
    initial_state = {
        "user_input": user_input,
        "similar_examples": None,
        "generated_code": None,
        "error": None,
        "task_type": None
    }

    result = app_workflow.invoke(initial_state)

    if result.get("error"):
        return f"Error: {result['error']}"

    if result["task_type"] == "generate":
        return f"Generated Code:\n\n{result['generated_code']}"
    elif result["task_type"] == "test":
        # The test branch stores its output in generated_tests/test_results,
        # not generated_code.
        output = f"Generated Tests:\n\n{result.get('generated_tests', '')}"
        test_results = result.get("test_results")
        if test_results:
            output += f"\n\nTest Execution Results:\n{test_results['stdout']}"
            if test_results["stderr"]:
                output += f"\nErrors:\n{test_results['stderr']}"
        return output
    else:
        return f"Code Explanation:\n\n{result['generated_code']}"

with gr.Blocks(title="Smart Code Assistant") as demo:
    gr.Markdown("""
# Smart Code Assistant
Enter a request to generate new code, explain existing code, or generate and run tests for it.
""")

    with gr.Row():
        input_text = gr.Textbox(label="Enter your request", placeholder="Example: Write a function to add two numbers... or Explain this code...")
        output_text = gr.Textbox(label="Result", interactive=False)

    submit_btn = gr.Button("Execute")
    submit_btn.click(fn=process_input, inputs=input_text, outputs=output_text)

    gr.Examples(
        examples=[
            ["Write a Python function to add two numbers"],
            ["Explain this code: for i in range(5): print(i)"],
            ["Create a function to convert temperature from Fahrenheit to Celsius"],
            ["test for i in range(3): print('Hello from test', i)"]
        ],
        inputs=input_text
    )


if __name__ == "__main__":
    demo.launch()