Syncbuz120 committed
Commit 0e92f07 · 1 Parent(s): 9b1d949

Prepare Flask backend for Hugging Face Spaces deployment

Files changed (7)
  1. .dockerignore +6 -0
  2. .gitignore +124 -0
  3. Dockerfile +33 -0
  4. README.md +0 -10
  5. app.py +273 -0
  6. model/generate.py +262 -0
  7. requirements.txt +0 -0
.dockerignore ADDED
@@ -0,0 +1,6 @@
+ __pycache__
+ *.pyc
+ .git
+ .vscode
+ *.log
+ tests
.gitignore ADDED
@@ -0,0 +1,124 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ pipenv.lock
+
+ # poetry
+ poetry.lock
+
+ # env files
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # C extensions
+ *.so
+
+ # VS Code settings
+ .vscode/
+
+ # PyCharm settings
+ .idea/
Dockerfile ADDED
@@ -0,0 +1,33 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy and install requirements
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir --upgrade pip
+ RUN pip install --no-cache-dir -r requirements.txt
+ RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+
+ # Set environment variables for AI models
+ ENV TRANSFORMERS_CACHE=/tmp/model_cache
+ ENV HF_HOME=/tmp/model_cache
+ ENV TOKENIZERS_PARALLELISM=false
+ ENV OMP_NUM_THREADS=1
+
+ # Create cache directory
+ RUN mkdir -p /tmp/model_cache
+
+ # ✅ Expose correct port for Hugging Face Spaces
+ EXPOSE 7860
+
+ # Copy application
+ COPY . .
+
+ # ✅ Run app on correct port
+ CMD ["gunicorn", "--bind", "0.0.0.0:7860", "--workers", "1", "--timeout", "120", "app:app"]
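
Note: the Dockerfile binds gunicorn to port 7860, which is the port Hugging Face Spaces routes traffic to for Docker Spaces. A minimal local smoke test (illustrative only, not part of this commit) could confirm the container answers on that port; it assumes the image was built and started with something like `docker build -t testcase-generator .` and `docker run -p 7860:7860 testcase-generator`, where the image tag is hypothetical.

# smoke_test.py - illustrative check that the running container answers on port 7860
import json
import urllib.request

# Query the /health endpoint added in app.py later in this commit
with urllib.request.urlopen("http://localhost:7860/health", timeout=30) as resp:
    payload = json.load(resp)

print(payload.get("status"), payload.get("memory"), payload.get("model_loaded"))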
README.md DELETED
@@ -1,10 +0,0 @@
- ---
- title: TestCaseGenerator
- emoji: 💻
- colorFrom: red
- colorTo: blue
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,273 @@
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ from model.generate import generate_test_cases, get_generator, monitor_memory
+ import os
+ import logging
+ import gc
+ import psutil
+ from functools import wraps
+ import time
+ import threading
+
+ # Configure logging for Railway
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ app = Flask(__name__)
+ CORS(app)
+
+ # Configuration for Railway
+ app.config['JSON_SORT_KEYS'] = False
+ app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False  # Reduce response size
+
+ # Thread-safe initialization
+ _init_lock = threading.Lock()
+ _initialized = False
+
+ def init_model():
+     """Initialize model on startup"""
+     try:
+         # Skip AI model loading in low memory environments
+         memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
+         if memory_mb > 200 or os.environ.get('RAILWAY_ENVIRONMENT'):
+             logger.info("⚠️ Skipping AI model loading due to memory constraints")
+             logger.info("🔧 Using template-based generation mode")
+             return True
+
+         logger.info("🚀 Initializing AI model...")
+         generator = get_generator()
+         model_info = generator.get_model_info()
+         logger.info(f"✅ Model initialized: {model_info['model_name']} | Memory: {model_info['memory_usage']}")
+         return True
+     except Exception as e:
+         logger.error(f"❌ Model initialization failed: {e}")
+         logger.info("🔧 Falling back to template-based generation")
+         return False
+
+ def check_health():
+     """Check system health"""
+     try:
+         memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
+         return {
+             "status": "healthy" if memory_mb < 450 else "warning",
+             "memory_usage": f"{memory_mb:.1f}MB",
+             "memory_limit": "512MB"
+         }
+     except Exception:
+         return {"status": "unknown", "memory_usage": "unavailable"}
+
+ def smart_memory_monitor(func):
+     """Enhanced memory monitoring with automatic cleanup"""
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         start_time = time.time()
+         try:
+             initial_memory = psutil.Process().memory_info().rss / 1024 / 1024
+             logger.info(f"🔍 {func.__name__} started | Memory: {initial_memory:.1f}MB")
+
+             if initial_memory > 400:
+                 logger.warning("⚠️ High memory detected, forcing cleanup...")
+                 gc.collect()
+
+             result = func(*args, **kwargs)
+             return result
+         except Exception as e:
+             logger.error(f"❌ Error in {func.__name__}: {str(e)}")
+             return jsonify({
+                 "error": "Internal server error occurred",
+                 "message": "Please try again or contact support"
+             }), 500
+         finally:
+             final_memory = psutil.Process().memory_info().rss / 1024 / 1024
+             execution_time = time.time() - start_time
+
+             logger.info(f"✅ {func.__name__} completed | Memory: {final_memory:.1f}MB | Time: {execution_time:.2f}s")
+
+             if final_memory > 450:
+                 logger.warning("🧹 High memory usage, forcing aggressive cleanup...")
+                 gc.collect()
+                 post_cleanup_memory = psutil.Process().memory_info().rss / 1024 / 1024
+                 logger.info(f"🧹 Post-cleanup memory: {post_cleanup_memory:.1f}MB")
+     return wrapper
+
+ def ensure_initialized():
+     """Ensure model is initialized (thread-safe)"""
+     global _initialized
+     if not _initialized:
+         with _init_lock:
+             if not _initialized:
+                 logger.info("🚀 Flask app starting up on Railway...")
+                 success = init_model()
+                 if success:
+                     logger.info("✅ Startup completed successfully")
+                 else:
+                     logger.warning("⚠️ Model initialization failed, using template mode")
+                 _initialized = True
+
+ @app.before_request
+ def before_request():
+     """Initialize model on first request (Flask 2.2+ compatible)"""
+     ensure_initialized()
+
+ @app.route('/')
+ def home():
+     """Health check endpoint with system status"""
+     health_data = check_health()
+     try:
+         generator = get_generator()
+         model_info = generator.get_model_info()
+     except Exception:
+         model_info = {
+             "model_name": "Template-Based Generator",
+             "status": "template_mode",
+             "optimization": "memory_safe"
+         }
+
+     return jsonify({
+         "message": "AI Test Case Generator Backend is running",
+         "status": health_data["status"],
+         "memory_usage": health_data["memory_usage"],
+         "model": {
+             "name": model_info["model_name"],
+             "status": model_info["status"],
+             "optimization": model_info.get("optimization", "standard")
+         },
+         "version": "1.0.0-railway-optimized"
+     })
+
+ @app.route('/health')
+ def health():
+     """Dedicated health check for Railway monitoring"""
+     health_status = check_health()
+     try:
+         generator = get_generator()
+         model_info = generator.get_model_info()
+         model_loaded = model_info["status"] == "loaded"
+     except Exception:
+         model_loaded = False
+
+     return jsonify({
+         "status": health_status["status"],
+         "memory": health_status["memory_usage"],
+         "model_loaded": model_loaded,
+         "uptime": "ok"
+     })
+
+ @app.route('/generate_test_cases', methods=['POST'])
+ @smart_memory_monitor
+ def generate():
+     """Generate test cases with enhanced error handling"""
+     if not request.is_json:
+         return jsonify({"error": "Request must be JSON"}), 400
+
+     data = request.get_json()
+     if not data:
+         return jsonify({"error": "No JSON data provided"}), 400
+
+     srs_text = data.get('srs', '').strip()
+
+     if not srs_text:
+         return jsonify({"error": "No SRS or prompt content provided"}), 400
+
+     if len(srs_text) > 5000:
+         logger.warning(f"SRS text truncated from {len(srs_text)} to 5000 characters")
+         srs_text = srs_text[:5000]
+
+     try:
+         logger.info(f"🎯 Generating test cases for input ({len(srs_text)} chars)")
+         test_cases = generate_test_cases(srs_text)
+
+         if not test_cases or len(test_cases) == 0:
+             logger.error("No test cases generated")
+             return jsonify({"error": "Failed to generate test cases"}), 500
+
+         try:
+             generator = get_generator()
+             model_info = generator.get_model_info()
+             model_used = model_info.get("model_name", "Unknown Model")
+             generation_method = model_info.get("status", "unknown")
+         except Exception:
+             model_used = "Template-Based Generator"
+             generation_method = "template_mode"
+
+         if model_used == "Template-Based Generator":
+             model_algorithm = "Rule-based Template"
+             model_reason = "Used rule-based generation due to memory constraints or fallback condition."
+         elif "distilgpt2" in model_used:
+             model_algorithm = "Transformer-based LM"
+             model_reason = "Used DistilGPT2 for balanced performance and memory efficiency."
+         elif "DialoGPT" in model_used:
+             model_algorithm = "Transformer-based LM"
+             model_reason = "Used DialoGPT-small as it fits within memory limits and handles conversational input well."
+         else:
+             model_algorithm = "Transformer-based LM"
+             model_reason = "Used available Hugging Face causal LM due to sufficient resources."
+
+         logger.info(f"✅ Successfully generated {len(test_cases)} test cases")
+
+         return jsonify({
+             "test_cases": test_cases,
+             "count": len(test_cases),
+             "model_used": model_used,
+             "generation_method": generation_method,
+             "model_algorithm": model_algorithm,
+             "model_reason": model_reason
+         })
+
+     except Exception as e:
+         logger.error(f"❌ Test case generation failed: {str(e)}")
+         return jsonify({
+             "error": "Failed to generate test cases",
+             "message": "Please try again with different input"
+         }), 500
+
+ @app.route('/model_info')
+ def model_info():
+     """Get current model information"""
+     try:
+         generator = get_generator()
+         info = generator.get_model_info()
+         health_data = check_health()
+
+         return jsonify({
+             "model": info,
+             "system": health_data
+         })
+     except Exception as e:
+         logger.error(f"Error getting model info: {e}")
+         return jsonify({"error": "Unable to get model information"}), 500
+
+ @app.errorhandler(404)
+ def not_found(error):
+     return jsonify({"error": "Endpoint not found"}), 404
+
+ @app.errorhandler(405)
+ def method_not_allowed(error):
+     return jsonify({"error": "Method not allowed"}), 405
+
+ @app.errorhandler(500)
+ def internal_error(error):
+     logger.error(f"Internal server error: {error}")
+     return jsonify({"error": "Internal server error"}), 500
+
+ if __name__ == '__main__':
+     port = int(os.environ.get("PORT", 5000))
+     debug_mode = os.environ.get("FLASK_ENV") == "development"
+
+     logger.info(f"🚀 Starting Flask app on port {port}")
+     logger.info(f"🔧 Debug mode: {debug_mode}")
+     logger.info(f"🖥️ Environment: {'Railway' if os.environ.get('RAILWAY_ENVIRONMENT') else 'Local'}")
+
+     if not os.environ.get('RAILWAY_ENVIRONMENT'):
+         ensure_initialized()
+
+     app.run(
+         host='0.0.0.0',
+         port=port,
+         debug=debug_mode,
+         threaded=True,
+         use_reloader=False
+     )
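
Note: the `/generate_test_cases` endpoint accepts a JSON body with an `srs` field (truncated to 5000 characters) and returns the generated cases plus metadata about the model that produced them. An illustrative client call (not part of this commit; host and port are assumptions for a local run) might look like the sketch below.

# example_request.py - illustrative client for the /generate_test_cases endpoint
import json
import urllib.request

body = json.dumps({
    "srs": "The system shall allow users to log in with a username and password."
}).encode("utf-8")

req = urllib.request.Request(
    "http://localhost:7860/generate_test_cases",
    data=body,
    headers={"Content-Type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(req, timeout=120) as resp:
    result = json.load(resp)

# Fields returned by app.py: model_used, count, test_cases (each with id/title/steps/expected)
print(result["model_used"], "|", result["count"])
for case in result["test_cases"]:
    print(case["id"], "-", case["title"])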
model/generate.py ADDED
@@ -0,0 +1,262 @@
+ import os
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import logging
+ import psutil
+ import re
+ import gc
+
+ # Initialize logger
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+ # List of memory-optimized models
+ MEMORY_OPTIMIZED_MODELS = [
+     "gpt2",                            # ~500MB
+     "distilgpt2",                      # ~250MB
+     "microsoft/DialoGPT-small",        # ~250MB
+     "huggingface/CodeBERTa-small-v1",  # Code tasks
+ ]
+
+ # Singleton state
+ _generator_instance = None
+
+ def get_optimal_model_for_memory():
+     """Select the best model based on available memory."""
+     available_memory = psutil.virtual_memory().available / (1024 * 1024)  # MB
+     logger.info(f"Available memory: {available_memory:.1f}MB")
+
+     if available_memory < 300:
+         return None  # Use template fallback
+     elif available_memory < 600:
+         return "microsoft/DialoGPT-small"
+     else:
+         return "distilgpt2"
+
+ def load_model_with_memory_optimization(model_name):
+     """Load model with low memory settings."""
+     try:
+         logger.info(f"Loading {model_name} with memory optimizations...")
+
+         tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', use_fast=True)
+
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.float16,
+             device_map="cpu",
+             low_cpu_mem_usage=True,
+             use_cache=False,
+         )
+
+         model.eval()
+         model.gradient_checkpointing_enable()
+         logger.info(f"✅ Model {model_name} loaded successfully")
+         return tokenizer, model
+
+     except Exception as e:
+         logger.error(f"❌ Failed to load model {model_name}: {e}")
+         return None, None
+
+ def extract_keywords(text):
+     common_keywords = [
+         'login', 'authentication', 'user', 'password', 'database', 'data',
+         'interface', 'api', 'function', 'feature', 'requirement', 'system',
+         'input', 'output', 'validation', 'error', 'security', 'performance'
+     ]
+     words = re.findall(r'\b\w+\b', text.lower())
+     return [word for word in words if word in common_keywords]
+
+ def generate_template_based_test_cases(srs_text):
+     keywords = extract_keywords(srs_text)
+     test_cases = []
+
+     if any(word in keywords for word in ['login', 'authentication', 'user', 'password']):
+         test_cases.extend([
+             {
+                 "id": "TC_001",
+                 "title": "Valid Login Test",
+                 "description": "Test login with valid credentials",
+                 "steps": ["Enter valid username", "Enter valid password", "Click login"],
+                 "expected": "User should be logged in successfully"
+             },
+             {
+                 "id": "TC_002",
+                 "title": "Invalid Login Test",
+                 "description": "Test login with invalid credentials",
+                 "steps": ["Enter invalid username", "Enter invalid password", "Click login"],
+                 "expected": "Error message should be displayed"
+             }
+         ])
+
+     if any(word in keywords for word in ['database', 'data', 'store', 'save']):
+         test_cases.append({
+             "id": "TC_003",
+             "title": "Data Storage Test",
+             "description": "Test data storage functionality",
+             "steps": ["Enter data", "Save data", "Verify storage"],
+             "expected": "Data should be stored correctly"
+         })
+
+     if not test_cases:
+         test_cases = [
+             {
+                 "id": "TC_001",
+                 "title": "Basic Functionality Test",
+                 "description": "Test basic system functionality",
+                 "steps": ["Access the system", "Perform basic operations", "Verify results"],
+                 "expected": "System should work as expected"
+             }
+         ]
+
+     return test_cases
+
+ def parse_generated_test_cases(generated_text):
+     lines = generated_text.split('\n')
+     test_cases = []
+     current_case = {}
+     case_counter = 1
+
+     for line in lines:
+         line = line.strip()
+         if line.startswith(('1.', '2.', '3.', 'TC', 'Test')):
+             if current_case:
+                 test_cases.append(current_case)
+             current_case = {
+                 "id": f"TC_{case_counter:03d}",
+                 "title": line,
+                 "description": line,
+                 "steps": ["Execute the test"],
+                 "expected": "Test should pass"
+             }
+             case_counter += 1
+
+     if current_case:
+         test_cases.append(current_case)
+
+     if not test_cases:
+         return [{
+             "id": "TC_001",
+             "title": "Generated Test Case",
+             "description": "Auto-generated test case based on requirements",
+             "steps": ["Review requirements", "Execute test", "Verify results"],
+             "expected": "Requirements should be met"
+         }]
+
+     return test_cases
+
+ def generate_with_ai_model(srs_text, tokenizer, model):
+     max_input_length = 200
+     if len(srs_text) > max_input_length:
+         srs_text = srs_text[:max_input_length]
+
+     prompt = f"""Generate test cases for this software requirement:
+ {srs_text}
+
+ Test Cases:
+ 1."""
+
+     try:
+         inputs = tokenizer.encode(
+             prompt,
+             return_tensors="pt",
+             max_length=150,
+             truncation=True
+         )
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 inputs,
+                 max_new_tokens=100,
+                 num_return_sequences=1,
+                 temperature=0.7,
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id,
+                 use_cache=False,
+             )
+
+         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         del inputs, outputs
+         torch.cuda.empty_cache() if torch.cuda.is_available() else None
+         return parse_generated_test_cases(generated_text)
+
+     except Exception as e:
+         logger.error(f"❌ AI generation failed: {e}")
+         raise
+
+ def generate_with_fallback(srs_text):
+     model_name = get_optimal_model_for_memory()
+
+     if model_name:
+         tokenizer, model = load_model_with_memory_optimization(model_name)
+         if tokenizer and model:
+             try:
+                 test_cases = generate_with_ai_model(srs_text, tokenizer, model)
+                 reason = get_algorithm_reason(model_name)
+                 return test_cases, model_name, "transformer (causal LM)", reason
+             except Exception as e:
+                 logger.warning(f"AI generation failed: {e}, falling back to templates")
+
+     logger.info("⚠️ Using fallback template-based generation")
+     test_cases = generate_template_based_test_cases(srs_text)
+     return test_cases, "Template-Based Generator", "rule-based", "Low memory - fallback to rule-based generation"
+
+ # ✅ Function exposed to app.py
+ def generate_test_cases(srs_text):
+     return generate_with_fallback(srs_text)[0]
+
+ def get_generator():
+     global _generator_instance
+     if _generator_instance is None:
+         class Generator:
+             def __init__(self):
+                 self.model_name = get_optimal_model_for_memory()
+                 self.tokenizer = None
+                 self.model = None
+                 if self.model_name:
+                     self.tokenizer, self.model = load_model_with_memory_optimization(self.model_name)
+
+             def get_model_info(self):
+                 mem = psutil.Process().memory_info().rss / 1024 / 1024
+                 return {
+                     "model_name": self.model_name if self.model_name else "Template-Based Generator",
+                     "status": "loaded" if self.model else "template_mode",
+                     "memory_usage": f"{mem:.1f}MB",
+                     "optimization": "low_memory"
+                 }
+
+         _generator_instance = Generator()
+
+     return _generator_instance
+
+ def monitor_memory():
+     mem = psutil.Process().memory_info().rss / 1024 / 1024
+     logger.info(f"Memory usage: {mem:.1f}MB")
+     if mem > 450:
+         gc.collect()
+         logger.info("Memory cleanup triggered")
+
+ # ✅ NEW FUNCTION for enhanced output: test cases + model info + reason
+ def generate_test_cases_and_info(input_text):
+     test_cases, model_name, algorithm_used, reason = generate_with_fallback(input_text)
+     return {
+         "model": model_name,
+         "algorithm": algorithm_used,
+         "reason": reason,
+         "test_cases": test_cases
+     }
+
+ # ✅ Explain why each algorithm is selected
+ def get_algorithm_reason(model_name):
+     if model_name == "microsoft/DialoGPT-small":
+         return "Selected due to low memory availability; DialoGPT-small provides conversational understanding in limited memory environments."
+     elif model_name == "distilgpt2":
+         return "Selected for its balance between performance and low memory usage. Ideal for small environments needing causal language modeling."
+     elif model_name == "gpt2":
+         return "Chosen for general-purpose text generation with moderate memory headroom."
+     elif model_name is None:
+         return "No model used due to insufficient memory. Rule-based template generation chosen instead."
+     else:
+         return "Model selected based on best tradeoff between memory usage and language generation capability."
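
Note: model/generate.py can also be exercised directly. `generate_test_cases_and_info` returns the test cases together with the model name, algorithm, and selection reason, and falls back to the rule-based templates when no model can be loaded. A small illustrative usage sketch (not part of this commit):

# Illustrative direct use of model/generate.py
from model.generate import generate_test_cases_and_info, monitor_memory

result = generate_test_cases_and_info(
    "Users must authenticate with a password before accessing stored data."
)

print(result["model"], "|", result["algorithm"])
print(result["reason"])
for case in result["test_cases"]:
    print(case["id"], case["title"])

monitor_memory()  # logs current RSS and triggers gc above the 450MB threshold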
requirements.txt ADDED
Binary file (2.48 kB).