juzer09 committed on
Commit 4b12a2e · verified · 1 Parent(s): 68cbc97

Update app.py

Files changed (1)
  1. app.py +385 -272
app.py CHANGED
@@ -1,272 +1,385 @@
- #!/usr/bin/env python3
- """
- Madverse Music - Hugging Face Spaces Version
- Streamlit app for HF Spaces deployment
- """
-
- import streamlit as st
- import torch
- import librosa
- import tempfile
- import os
- import time
- import numpy as np
-
- # Import the sonics library for model loading
- try:
-     from sonics import HFAudioClassifier
- except ImportError:
-     st.error("Sonics library not found. Please install it first.")
-     st.stop()
-
- # Global model variable
- model = None
-
- # Page configuration
- st.set_page_config(
-     page_title="Madverse Music: AI Music Detector",
-     page_icon="🎵",
-     layout="wide",
-     initial_sidebar_state="expanded"
- )
-
- # Custom CSS
- st.markdown("""
- <style>
-     .main-header {
-         background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
-         padding: 1rem;
-         border-radius: 10px;
-         color: white;
-         text-align: center;
-         margin-bottom: 2rem;
-     }
-     .result-box {
-         padding: 1rem;
-         border-radius: 10px;
-         margin: 1rem 0;
-         border-left: 5px solid;
-     }
-     .real-music {
-         background-color: #d4edda;
-         border-left-color: #28a745;
-     }
-     .fake-music {
-         background-color: #f8d7da;
-         border-left-color: #dc3545;
-     }
- </style>
- """, unsafe_allow_html=True)
-
- @st.cache_resource
- def load_model():
-     """Load the model with caching for HF Spaces"""
-     try:
-         with st.spinner("Loading AI model... This may take a moment..."):
-             # Use the same loading method as the working API
-             model = HFAudioClassifier.from_pretrained("awsaf49/sonics-spectttra-alpha-120s")
-             model.eval()
-             return model
-     except Exception as e:
-         st.error(f"Failed to load model: {str(e)}")
-         return None
-
- def process_audio(audio_file, model):
-     """Process audio file and return classification"""
-     try:
-         # Save uploaded file temporarily
-         with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
-             tmp_file.write(audio_file.read())
-             tmp_path = tmp_file.name
-
-         # Load audio (model uses 16kHz sample rate)
-         audio, sr = librosa.load(tmp_path, sr=16000)
-
-         # Convert to tensor and add batch dimension
-         audio_tensor = torch.FloatTensor(audio).unsqueeze(0)
-
-         # Get prediction using the same pattern as working API
-         with torch.no_grad():
-             output = model(audio_tensor)
-
-         # Convert logit to probability using sigmoid
-         probability = torch.sigmoid(output).item()
-
-         # Classify: prob < 0.5 = Real, prob >= 0.5 = Fake
-         if probability < 0.5:
-             classification = "Real"
-             confidence = (1 - probability) * 2  # Convert to 0-1 scale
-         else:
-             classification = "Fake"
-             confidence = (probability - 0.5) * 2  # Convert to 0-1 scale
-
-         # Calculate duration
-         duration = len(audio) / sr
-
-         # Clean up
-         os.unlink(tmp_path)
-
-         return {
-             'classification': classification,
-             'confidence': min(confidence, 1.0),  # Cap at 1.0
-             'probability': probability,
-             'raw_score': output.item(),
-             'duration': duration,
-             'success': True
-         }
-
-     except Exception as e:
-         # Clean up on error
-         if 'tmp_path' in locals():
-             try:
-                 os.unlink(tmp_path)
-             except:
-                 pass
-         return {
-             'success': False,
-             'error': str(e)
-         }
-
- def main():
-     # Header
-     st.markdown("""
-     <div class="main-header">
-         <h1>Madverse Music: AI Music Detector</h1>
-         <p>Detect AI-generated music vs human-created music using advanced AI technology</p>
-     </div>
-     """, unsafe_allow_html=True)
-
-     # Sidebar
-     with st.sidebar:
-         st.markdown("### About")
-         st.markdown("""
-         This AI model can detect whether music is:
-         - **Real**: Human-created music
-         - **Fake**: AI-generated music (Suno, Udio, etc.)
-
-         **Model**: SpecTTTra-α (120s)
-         **Accuracy**: 97% F1 score
-         **Max Duration**: 120 seconds
-         """)
-
-         st.markdown("### Supported Formats")
-         st.markdown("- WAV (.wav)")
-         st.markdown("- MP3 (.mp3)")
-         st.markdown("- FLAC (.flac)")
-         st.markdown("- M4A (.m4a)")
-         st.markdown("- OGG (.ogg)")
-
-         st.markdown("### Links")
-         st.markdown("- [Madverse Website](https://madverse.co)")
-         st.markdown("- [GitHub Repository](#)")
-
-     # Load model
-     model = load_model()
-
-     if model is None:
-         st.error("Model failed to load. Please refresh the page.")
-         return
-
-     st.success("AI model loaded successfully!")
-
-     # File upload
-     st.markdown("### Upload Audio File")
-     uploaded_file = st.file_uploader(
-         "Choose an audio file",
-         type=['wav', 'mp3', 'flac', 'm4a', 'ogg'],
-         help="Upload an audio file to analyze (max 120 seconds)"
-     )
-
-     if uploaded_file is not None:
-         # Display file info
-         st.markdown("### File Information")
-         col1, col2, col3 = st.columns(3)
-
-         with col1:
-             st.metric("Filename", uploaded_file.name)
-         with col2:
-             st.metric("File Size", f"{uploaded_file.size / 1024:.1f} KB")
-         with col3:
-             st.metric("Format", uploaded_file.type)
-
-         # Audio player
-         st.markdown("### Preview")
-         st.audio(uploaded_file)
-
-         # Analysis button
-         if st.button("Analyze Audio", type="primary", use_container_width=True):
-             try:
-                 with st.spinner("Analyzing audio... This may take a few seconds..."):
-                     # Reset file pointer
-                     uploaded_file.seek(0)
-
-                     # Process audio
-                     start_time = time.time()
-                     result = process_audio(uploaded_file, model)
-                     processing_time = time.time() - start_time
-
-                     if not result['success']:
-                         st.error(f"Error processing audio: {result['error']}")
-                         return
-
-                     # Display results
-                     st.markdown("### Analysis Results")
-
-                     classification = result['classification']
-                     confidence = result['confidence']
-
-                     # Result box
-                     if classification == "Real":
-                         st.markdown(f"""
-                         <div class="result-box real-music">
-                             <h3>Result: Human-Created Music</h3>
-                             <p><strong>Classification:</strong> {classification}</p>
-                             <p><strong>Confidence:</strong> {confidence:.1%}</p>
-                             <p><strong>Message:</strong> This appears to be human-created music!</p>
-                         </div>
-                         """, unsafe_allow_html=True)
-                     else:
-                         st.markdown(f"""
-                         <div class="result-box fake-music">
-                             <h3>Result: AI-Generated Music</h3>
-                             <p><strong>Classification:</strong> {classification}</p>
-                             <p><strong>Confidence:</strong> {confidence:.1%}</p>
-                             <p><strong>Message:</strong> This appears to be AI-generated music!</p>
-                         </div>
-                         """, unsafe_allow_html=True)
-
-                     # Detailed metrics
-                     with st.expander("Detailed Metrics"):
-                         col1, col2, col3 = st.columns(3)
-
-                         with col1:
-                             st.metric("Confidence", f"{confidence:.1%}")
-                         with col2:
-                             st.metric("Probability", f"{result['probability']:.3f}")
-                         with col3:
-                             st.metric("Processing Time", f"{processing_time:.2f}s")
-
-                         if result['duration'] > 0:
-                             st.metric("Duration", f"{result['duration']:.1f}s")
-
-                         st.markdown("**Interpretation:**")
-                         st.markdown("""
-                         - **Probability < 0.5**: Classified as Real (human-created)
-                         - **Probability ≥ 0.5**: Classified as Fake (AI-generated)
-                         - **Confidence**: How certain the model is about its prediction
-                         """)
-
-             except Exception as e:
-                 st.error(f"Error processing audio: {str(e)}")
-
-     # Footer
-     st.markdown("---")
-     st.markdown("""
-     <div style="text-align: center; color: #666;">
-         <p>Powered by <strong>Madverse Music</strong> | Built with Streamlit & PyTorch</p>
-         <p>This tool is for research and educational purposes. Results may vary depending on audio quality.</p>
-     </div>
-     """, unsafe_allow_html=True)
-
- if __name__ == "__main__":
-     main()
+ #!/usr/bin/env python3
+ """
+ Madverse Music API
+ AI Music Detection Service
+ """
+
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Header, Depends
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+ from pydantic import BaseModel, HttpUrl
+ import torch
+ import librosa
+ import tempfile
+ import os
+ import requests
+ from pathlib import Path
+ import time
+ from typing import Optional, Annotated, List
+ import uvicorn
+ import asyncio
+
+ # Initialize FastAPI app
+ app = FastAPI(
+     title="Madverse Music API",
+     description="AI-powered music detection API to identify AI-generated vs human-created music",
+     version="1.0.0",
+     docs_url="/",
+     redoc_url="/docs"
+ )
+
+ # API Key Configuration
+ API_KEY = os.getenv("MADVERSE_API_KEY", "madverse-music-api-key-2024")  # Default key for demo
+
+ # Global model variable
+ model = None
+
+ async def verify_api_key(x_api_key: Annotated[str | None, Header()] = None):
+     """Verify API key from header"""
+     if x_api_key is None:
+         raise HTTPException(
+             status_code=401,
+             detail="Missing API key. Please provide a valid X-API-Key header."
+         )
+     if x_api_key != API_KEY:
+         raise HTTPException(
+             status_code=401,
+             detail="Invalid API key. Please provide a valid X-API-Key header."
+         )
+     return x_api_key
+
+ class MusicAnalysisRequest(BaseModel):
+     urls: List[HttpUrl]
+
+ def check_api_key_first(request: MusicAnalysisRequest, x_api_key: Annotated[str | None, Header()] = None):
+     """Check API key before processing request"""
+     if x_api_key is None:
+         raise HTTPException(
+             status_code=401,
+             detail="Missing API key. Please provide a valid X-API-Key header."
+         )
+     if x_api_key != API_KEY:
+         raise HTTPException(
+             status_code=401,
+             detail="Invalid API key. Please provide a valid X-API-Key header."
+         )
+     return request
+
+ class FileAnalysisResult(BaseModel):
+     url: str
+     success: bool
+     classification: Optional[str] = None  # "Real" or "Fake"
+     confidence: Optional[float] = None  # 0.0 to 1.0
+     probability: Optional[float] = None  # Raw sigmoid probability
+     raw_score: Optional[float] = None  # Raw model output
+     duration: Optional[float] = None  # Audio duration in seconds
+     message: str
+     processing_time: Optional[float] = None
+     error: Optional[str] = None
+
+ class MusicAnalysisResponse(BaseModel):
+     success: bool
+     total_files: int
+     successful_analyses: int
+     failed_analyses: int
+     results: List[FileAnalysisResult]
+     total_processing_time: float
+     message: str
+
+ class ErrorResponse(BaseModel):
+     success: bool
+     error: str
+     message: str
+
+ @app.on_event("startup")
+ async def load_model():
+     """Load the AI model on startup"""
+     global model
+     try:
+         from sonics import HFAudioClassifier
+         print("🔄 Loading Madverse Music AI model...")
+         model = HFAudioClassifier.from_pretrained("awsaf49/sonics-spectttra-alpha-120s")
+         model.eval()
+         print("✅ Model loaded successfully!")
+     except Exception as e:
+         print(f"❌ Failed to load model: {e}")
+         raise
+
+ def cleanup_file(file_path: str):
+     """Background task to cleanup temporary files"""
+     try:
+         if os.path.exists(file_path):
+             os.unlink(file_path)
+     except:
+         pass
+
+ def download_audio(url: str, max_size_mb: int = 100) -> str:
+     """Download audio file from URL with size validation"""
+     try:
+         # Check if URL is accessible
+         response = requests.head(str(url), timeout=10)
+
+         # Check content size
+         content_length = response.headers.get('Content-Length')
+         if content_length and int(content_length) > max_size_mb * 1024 * 1024:
+             raise HTTPException(
+                 status_code=413,
+                 detail=f"File too large. Maximum size: {max_size_mb}MB"
+             )
+
+         # Download file
+         response = requests.get(str(url), timeout=30, stream=True)
+         response.raise_for_status()
+
+         # Create temporary file
+         with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as tmp_file:
+             downloaded_size = 0
+             for chunk in response.iter_content(chunk_size=8192):
+                 downloaded_size += len(chunk)
+                 if downloaded_size > max_size_mb * 1024 * 1024:
+                     os.unlink(tmp_file.name)
+                     raise HTTPException(
+                         status_code=413,
+                         detail=f"File too large. Maximum size: {max_size_mb}MB"
+                     )
+                 tmp_file.write(chunk)
+
+         return tmp_file.name
+
+     except requests.exceptions.RequestException as e:
+         raise HTTPException(
+             status_code=400,
+             detail=f"Failed to download audio: {str(e)}"
+         )
+     except Exception as e:
+         raise HTTPException(
+             status_code=500,
+             detail=f"Error downloading file: {str(e)}"
+         )
+
+ def classify_audio(file_path: str) -> dict:
+     """Classify audio file using the AI model"""
+     try:
+         # Load audio (model uses 16kHz sample rate)
+         audio, sr = librosa.load(file_path, sr=16000)
+
+         # Convert to tensor and add batch dimension
+         audio_tensor = torch.FloatTensor(audio).unsqueeze(0)
+
+         # Get prediction
+         with torch.no_grad():
+             output = model(audio_tensor)
+
+         # Convert logit to probability using sigmoid
+         prob = torch.sigmoid(output).item()
+
+         # Classify: prob < 0.5 = Real, prob >= 0.5 = Fake
+         if prob < 0.5:
+             classification = "Real"
+             confidence = (1 - prob) * 2  # Convert to 0-1 scale
+         else:
+             classification = "Fake"
+             confidence = (prob - 0.5) * 2  # Convert to 0-1 scale
+
+         return {
+             "classification": classification,
+             "confidence": min(confidence, 1.0),  # Cap at 1.0
+             "probability": prob,
+             "raw_score": output.item(),
+             "duration": len(audio) / sr
+         }
+
+     except Exception as e:
+         raise HTTPException(
+             status_code=500,
+             detail=f"Error analyzing audio: {str(e)}"
+         )
+
+ async def process_single_url(url: str) -> FileAnalysisResult:
+     """Process a single URL and return result"""
+     start_time = time.time()
+
+     try:
+         # Download audio file
+         temp_file = download_audio(url)
+
+         # Classify audio
+         result = classify_audio(temp_file)
+
+         # Calculate processing time
+         processing_time = time.time() - start_time
+
+         # Cleanup file in background
+         try:
+             os.unlink(temp_file)
+         except:
+             pass
+
+         # Prepare response
+         emoji = "🎤" if result["classification"] == "Real" else "🤖"
+         message = f'{emoji} Detected as {result["classification"].lower()} music'
+
+         return FileAnalysisResult(
+             url=str(url),
+             success=True,
+             classification=result["classification"],
+             confidence=result["confidence"],
+             probability=result["probability"],
+             raw_score=result["raw_score"],
+             duration=result["duration"],
+             message=message,
+             processing_time=processing_time
+         )
+
+     except Exception as e:
+         processing_time = time.time() - start_time
+         error_msg = str(e)
+
+         return FileAnalysisResult(
+             url=str(url),
+             success=False,
+             message=f"❌ Failed to process: {error_msg}",
+             processing_time=processing_time,
+             error=error_msg
+         )
+
+ @app.post("/analyze", response_model=MusicAnalysisResponse)
+ async def analyze_music(
+     request: MusicAnalysisRequest = Depends(check_api_key_first)
+ ):
+     """
+     Analyze music from URL(s) to detect if it's AI-generated or human-created
+
+     - **urls**: Array of direct URLs to audio files (MP3, WAV, FLAC, M4A, OGG)
+     - Returns classification results for each file
+     - Processes files concurrently for better performance when multiple URLs provided
+     """
+     start_time = time.time()
+
+     if not model:
+         raise HTTPException(
+             status_code=503,
+             detail="Model not loaded. Please try again later."
+         )
+
+     if len(request.urls) > 50:  # Limit processing
+         raise HTTPException(
+             status_code=400,
+             detail="Too many URLs. Maximum 50 files per request."
+         )
+
+     if len(request.urls) == 0:
+         raise HTTPException(
+             status_code=400,
+             detail="At least one URL is required."
+         )
+
+     try:
+         # Process all URLs concurrently with limited concurrency
+         semaphore = asyncio.Semaphore(5)  # Limit to 5 concurrent downloads
+
+         async def process_with_semaphore(url):
+             async with semaphore:
+                 return await process_single_url(str(url))
+
+         # Create tasks for all URLs
+         tasks = [process_with_semaphore(url) for url in request.urls]
+
+         # Wait for all tasks to complete
+         results = await asyncio.gather(*tasks, return_exceptions=True)
+
+         # Process results and handle any exceptions
+         processed_results = []
+         successful_count = 0
+         failed_count = 0
+
+         for i, result in enumerate(results):
+             if isinstance(result, Exception):
+                 # Handle exception case
+                 processed_results.append(FileAnalysisResult(
+                     url=str(request.urls[i]),
+                     success=False,
+                     message=f"❌ Processing failed: {str(result)}",
+                     error=str(result)
+                 ))
+                 failed_count += 1
+             else:
+                 processed_results.append(result)
+                 if result.success:
+                     successful_count += 1
+                 else:
+                     failed_count += 1
+
+         # Calculate total processing time
+         total_processing_time = time.time() - start_time
+
+         # Prepare summary message
+         total_files = len(request.urls)
+         if total_files == 1:
+             # Single file message
+             if successful_count == 1:
+                 message = processed_results[0].message
+             else:
+                 message = processed_results[0].message
+         else:
+             # Multiple files message
+             if successful_count == total_files:
+                 message = f"✅ Successfully analyzed all {total_files} files"
+             elif successful_count > 0:
+                 message = f"⚠️ Analyzed {successful_count}/{total_files} files successfully"
+             else:
+                 message = f"❌ Failed to analyze any files"
+
+         return MusicAnalysisResponse(
+             success=successful_count > 0,
+             total_files=total_files,
+             successful_analyses=successful_count,
+             failed_analyses=failed_count,
+             results=processed_results,
+             total_processing_time=total_processing_time,
+             message=message
+         )
+
+     except Exception as e:
+         raise HTTPException(
+             status_code=500,
+             detail=f"Internal server error during processing: {str(e)}"
+         )
+
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     return {
+         "status": "healthy",
+         "model_loaded": model is not None,
+         "service": "Madverse Music API"
+     }
+
+ @app.get("/info")
+ async def get_info():
+     """Get API information"""
+     return {
+         "name": "Madverse Music API",
+         "version": "1.0.0",
+         "description": "AI-powered music detection to identify AI-generated vs human-created music",
+         "model": "SpecTTTra-α (120s)",
+         "accuracy": {
+             "f1_score": 0.97,
+             "sensitivity": 0.96,
+             "specificity": 0.99
+         },
+         "supported_formats": ["MP3", "WAV", "FLAC", "M4A", "OGG"],
+         "max_file_size": "100MB",
+         "max_duration": "120 seconds",
+         "authentication": {
+             "required": True,
+             "type": "API Key",
+             "header": "X-API-Key",
+             "example": "X-API-Key: your-api-key-here"
+         },
+         "usage": {
+             "curl_example": "curl -X POST 'http://localhost:8000/analyze' -H 'X-API-Key: your-api-key' -H 'Content-Type: application/json' -d '{\"urls\": [\"https://example.com/song.mp3\"]}'"
+         }
+     }
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
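
For reference, a minimal client sketch for the new /analyze endpoint added in this commit. It assumes the service is reachable at http://localhost:8000 and that the demo key baked into app.py is still active; the base URL, key value, and script name are illustrative assumptions, not part of the commit itself.

# client_example.py - hypothetical usage sketch, not part of this commit
import requests

BASE_URL = "http://localhost:8000"        # assumed local deployment
API_KEY = "madverse-music-api-key-2024"   # demo default from app.py; override via MADVERSE_API_KEY in production

# POST a batch of direct audio URLs (up to 50 per request) with the X-API-Key header
payload = {"urls": ["https://example.com/song.mp3"]}
response = requests.post(
    f"{BASE_URL}/analyze",
    json=payload,
    headers={"X-API-Key": API_KEY},
    timeout=120,
)
response.raise_for_status()

# Each entry in "results" mirrors FileAnalysisResult: classification is "Real" or "Fake"
for item in response.json()["results"]:
    if item["success"]:
        print(item["url"], item["classification"], f'{item["confidence"]:.1%}')
    else:
        print(item["url"], "failed:", item["error"])

The unauthenticated /health endpoint can be polled first to confirm the model has finished loading before sending analysis requests.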