| """ | |
| Database integration for Streamlit application. | |
| This module provides functions to interact with the database for the Streamlit frontend. | |
| It wraps the async database functions in sync functions for Streamlit compatibility. | |
| """ | |
| import os | |
| import asyncio | |
| import pandas as pd | |
| from typing import List, Dict, Any, Optional, Union, Tuple | |
| from datetime import datetime, timedelta | |
| from sqlalchemy.orm import sessionmaker | |
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession | |
| # Import database models | |
| from src.models.threat import Threat, ThreatSeverity, ThreatStatus, ThreatCategory | |
| from src.models.indicator import Indicator, IndicatorType | |
| from src.models.dark_web_content import DarkWebContent, DarkWebMention, ContentType, ContentStatus | |
| from src.models.alert import Alert, AlertStatus, AlertCategory | |
| from src.models.report import Report, ReportType, ReportStatus | |
| # Import service functions | |
| from src.api.services.dark_web_content_service import ( | |
| create_content, get_content_by_id, get_contents, count_contents, | |
| create_mention, get_mentions, create_threat_from_content | |
| ) | |
| from src.api.services.alert_service import ( | |
| create_alert, get_alert_by_id, get_alerts, count_alerts, | |
| update_alert_status, mark_alert_as_read, get_alert_counts_by_severity | |
| ) | |
| from src.api.services.threat_service import ( | |
| create_threat, get_threat_by_id, get_threats, count_threats, | |
| update_threat, add_indicator_to_threat, get_threat_statistics | |
| ) | |
| from src.api.services.report_service import ( | |
| create_report, get_report_by_id, get_reports, count_reports, | |
| update_report, add_threat_to_report, publish_report | |
| ) | |
| # Import schemas | |
| from src.api.schemas import PaginationParams | |
| # Get database URL from environment | |
| db_url = os.getenv("DATABASE_URL", "") | |
| if db_url.startswith("postgresql://"): | |
| # Remove sslmode parameter if present which causes issues with asyncpg | |
| if "?" in db_url: | |
| base_url, params = db_url.split("?", 1) | |
| param_list = params.split("&") | |
| filtered_params = [p for p in param_list if not p.startswith("sslmode=")] | |
| if filtered_params: | |
| db_url = f"{base_url}?{'&'.join(filtered_params)}" | |
| else: | |
| db_url = base_url | |
| ASYNC_DATABASE_URL = db_url.replace("postgresql://", "postgresql+asyncpg://", 1) | |
| else: | |
| ASYNC_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres" | |
| # Create async engine | |
| engine = create_async_engine( | |
| ASYNC_DATABASE_URL, | |
| echo=False, | |
| future=True, | |
| pool_size=5, | |
| max_overflow=10 | |
| ) | |
| # Create async session factory | |
| async_session = sessionmaker( | |
| engine, | |
| class_=AsyncSession, | |
| expire_on_commit=False | |
| ) | |


def run_async(coro):
    """Run an async coroutine to completion from a sync (Streamlit) context."""
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No event loop in the current thread (common under Streamlit); create one
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)
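

# A minimal usage sketch for run_async: any coroutine from the imported service
# layer can be driven from Streamlit's synchronous code, e.g.
#
#     session = get_db_session()
#     stats = run_async(get_threat_statistics(db=session))
#
# get_db_session is defined below; the `db` keyword argument mirrors how the
# sync wrappers in this module call the service functions.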


async def get_session():
    """Get an async database session."""
    async with async_session() as session:
        yield session


def get_db_session():
    """
    Get a database session for use in Streamlit.

    The async generator above is advanced a single step so that the session it
    yields remains open; the sync wrappers below pass it into the async service
    functions via run_async().
    """
    try:
        session_gen = get_session().__aiter__()
        return run_async(session_gen.__anext__())
    except StopAsyncIteration:
        return None


@asynccontextmanager
async def get_async_session():
    """
    Async context manager for database sessions.

    Usage:
        async with get_async_session() as session:
            # Use session here
    """
    session = async_session()
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()
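

# Usage sketch for get_async_session (only valid from async code such as a
# background task, not directly inside a Streamlit script run). The call to
# mark_alert_as_read assumes it accepts `db` and `alert_id` keyword arguments,
# matching the other service calls in this module:
#
#     async def _mark_read(alert_id: int) -> None:
#         async with get_async_session() as session:
#             await mark_alert_as_read(db=session, alert_id=alert_id)
#
# Commit, rollback, and close are handled by the context manager itself.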


# Dark Web Content functions

def get_dark_web_contents(
    page: int = 1,
    size: int = 10,
    content_type: Optional[List[ContentType]] = None,
    content_status: Optional[List[ContentStatus]] = None,
    source_name: Optional[str] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> pd.DataFrame:
    """
    Get dark web contents as a DataFrame.

    Args:
        page: Page number
        size: Page size
        content_type: Filter by content type
        content_status: Filter by content status
        source_name: Filter by source name
        search_query: Search in title and content
        from_date: Filter by scraped_at >= from_date
        to_date: Filter by scraped_at <= to_date

    Returns:
        pd.DataFrame: DataFrame with dark web contents
    """
    session = get_db_session()
    if not session:
        return pd.DataFrame()

    contents = run_async(get_contents(
        db=session,
        pagination=PaginationParams(page=page, size=size),
        content_type=content_type,
        content_status=content_status,
        source_name=source_name,
        search_query=search_query,
        from_date=from_date,
        to_date=to_date,
    ))

    if not contents:
        return pd.DataFrame()

    # Convert to DataFrame
    data = []
    for content in contents:
        data.append({
            "id": content.id,
            "url": content.url,
            "title": content.title,
            "content_type": content.content_type.value if content.content_type else None,
            "content_status": content.content_status.value if content.content_status else None,
            "source_name": content.source_name,
            "source_type": content.source_type,
            "language": content.language,
            "scraped_at": content.scraped_at,
            "relevance_score": content.relevance_score,
            "sentiment_score": content.sentiment_score,
        })

    return pd.DataFrame(data)
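

# Typical Streamlit usage (illustrative; assumes `import streamlit as st` in the
# calling page):
#
#     df = get_dark_web_contents(page=1, size=25, search_query="ransomware")
#     if df.empty:
#         st.info("No dark web content found.")
#     else:
#         st.dataframe(df)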


def add_dark_web_content(
    url: str,
    content: str,
    title: Optional[str] = None,
    content_type: ContentType = ContentType.OTHER,
    source_name: Optional[str] = None,
    source_type: Optional[str] = None,
) -> Optional[DarkWebContent]:
    """
    Add a new dark web content record.

    Args:
        url: URL of the content
        content: Text content
        title: Title of the content
        content_type: Type of content
        source_name: Name of the source
        source_type: Type of source

    Returns:
        Optional[DarkWebContent]: Created content or None
    """
    session = get_db_session()
    if not session:
        return None

    return run_async(create_content(
        db=session,
        url=url,
        content=content,
        title=title,
        content_type=content_type,
        source_name=source_name,
        source_type=source_type,
    ))


def get_dark_web_mentions(
    page: int = 1,
    size: int = 10,
    keyword: Optional[str] = None,
    content_id: Optional[int] = None,
    is_verified: Optional[bool] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> pd.DataFrame:
    """
    Get dark web mentions as a DataFrame.

    Args:
        page: Page number
        size: Page size
        keyword: Filter by keyword
        content_id: Filter by content ID
        is_verified: Filter by verification status
        from_date: Filter by created_at >= from_date
        to_date: Filter by created_at <= to_date

    Returns:
        pd.DataFrame: DataFrame with dark web mentions
    """
    session = get_db_session()
    if not session:
        return pd.DataFrame()

    mentions = run_async(get_mentions(
        db=session,
        pagination=PaginationParams(page=page, size=size),
        keyword=keyword,
        content_id=content_id,
        is_verified=is_verified,
        from_date=from_date,
        to_date=to_date,
    ))

    if not mentions:
        return pd.DataFrame()

    # Convert to DataFrame
    data = []
    for mention in mentions:
        data.append({
            "id": mention.id,
            "content_id": mention.content_id,
            "keyword": mention.keyword,
            "snippet": mention.snippet,
            "mention_type": mention.mention_type,
            "confidence": mention.confidence,
            "is_verified": mention.is_verified,
            "created_at": mention.created_at,
        })

    return pd.DataFrame(data)


def add_dark_web_mention(
    content_id: int,
    keyword: str,
    context: Optional[str] = None,
    snippet: Optional[str] = None,
) -> Optional[DarkWebMention]:
    """
    Add a new dark web mention.

    Args:
        content_id: ID of the content where the mention was found
        keyword: Keyword that was mentioned
        context: Text surrounding the mention
        snippet: Extract of text containing the mention

    Returns:
        Optional[DarkWebMention]: Created mention or None
    """
    session = get_db_session()
    if not session:
        return None

    return run_async(create_mention(
        db=session,
        content_id=content_id,
        keyword=keyword,
        context=context,
        snippet=snippet,
    ))


# Alerts functions

def get_alerts_df(
    page: int = 1,
    size: int = 10,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[AlertStatus]] = None,
    category: Optional[List[AlertCategory]] = None,
    is_read: Optional[bool] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> pd.DataFrame:
    """
    Get alerts as a DataFrame.

    Args:
        page: Page number
        size: Page size
        severity: Filter by severity
        status: Filter by status
        category: Filter by category
        is_read: Filter by read status
        search_query: Search in title and description
        from_date: Filter by generated_at >= from_date
        to_date: Filter by generated_at <= to_date

    Returns:
        pd.DataFrame: DataFrame with alerts
    """
    session = get_db_session()
    if not session:
        return pd.DataFrame()

    alerts = run_async(get_alerts(
        db=session,
        pagination=PaginationParams(page=page, size=size),
        severity=severity,
        status=status,
        category=category,
        is_read=is_read,
        search_query=search_query,
        from_date=from_date,
        to_date=to_date,
    ))

    if not alerts:
        return pd.DataFrame()

    # Convert to DataFrame
    data = []
    for alert in alerts:
        data.append({
            "id": alert.id,
            "title": alert.title,
            "description": alert.description,
            "severity": alert.severity.value if alert.severity else None,
            "status": alert.status.value if alert.status else None,
            "category": alert.category.value if alert.category else None,
            "generated_at": alert.generated_at,
            "source_url": alert.source_url,
            "is_read": alert.is_read,
            "threat_id": alert.threat_id,
            "mention_id": alert.mention_id,
            "assigned_to_id": alert.assigned_to_id,
            "action_taken": alert.action_taken,
            "resolved_at": alert.resolved_at,
        })

    return pd.DataFrame(data)
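

# Illustrative filter call (a sketch; the severity enum member names used here,
# CRITICAL and HIGH, are assumed and should match src.models.threat.ThreatSeverity):
#
#     df = get_alerts_df(
#         severity=[ThreatSeverity.CRITICAL, ThreatSeverity.HIGH],
#         is_read=False,
#         from_date=datetime.utcnow() - timedelta(days=7),
#     )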


def add_alert(
    title: str,
    description: str,
    severity: ThreatSeverity,
    category: AlertCategory,
    source_url: Optional[str] = None,
    threat_id: Optional[int] = None,
    mention_id: Optional[int] = None,
) -> Optional[Alert]:
    """
    Add a new alert.

    Args:
        title: Alert title
        description: Alert description
        severity: Alert severity
        category: Alert category
        source_url: Source URL for the alert
        threat_id: ID of related threat
        mention_id: ID of related dark web mention

    Returns:
        Optional[Alert]: Created alert or None
    """
    session = get_db_session()
    if not session:
        return None

    return run_async(create_alert(
        db=session,
        title=title,
        description=description,
        severity=severity,
        category=category,
        source_url=source_url,
        threat_id=threat_id,
        mention_id=mention_id,
    ))


def update_alert(
    alert_id: int,
    status: AlertStatus,
    action_taken: Optional[str] = None,
) -> Optional[Alert]:
    """
    Update alert status.

    Args:
        alert_id: Alert ID
        status: New status
        action_taken: Description of action taken

    Returns:
        Optional[Alert]: Updated alert or None
    """
    session = get_db_session()
    if not session:
        return None

    return run_async(update_alert_status(
        db=session,
        alert_id=alert_id,
        status=status,
        action_taken=action_taken,
    ))


def get_alert_severity_counts(
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> Dict[str, int]:
    """
    Get count of alerts by severity.

    Args:
        from_date: Filter by generated_at >= from_date
        to_date: Filter by generated_at <= to_date

    Returns:
        Dict[str, int]: Mapping of severity to count
    """
    session = get_db_session()
    if not session:
        return {}

    return run_async(get_alert_counts_by_severity(
        db=session,
        from_date=from_date,
        to_date=to_date,
    ))
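

# Example (illustrative; assumes `import streamlit as st` in the calling page):
# the severity counts feed directly into a Streamlit chart.
#
#     counts = get_alert_severity_counts()
#     if counts:
#         st.bar_chart(pd.Series(counts, name="alerts"))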


# Threats functions

def get_threats_df(
    page: int = 1,
    size: int = 10,
    severity: Optional[List[ThreatSeverity]] = None,
    status: Optional[List[ThreatStatus]] = None,
    category: Optional[List[ThreatCategory]] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> pd.DataFrame:
    """
    Get threats as a DataFrame.

    Args:
        page: Page number
        size: Page size
        severity: Filter by severity
        status: Filter by status
        category: Filter by category
        search_query: Search in title and description
        from_date: Filter by discovered_at >= from_date
        to_date: Filter by discovered_at <= to_date

    Returns:
        pd.DataFrame: DataFrame with threats
    """
    session = get_db_session()
    if not session:
        return pd.DataFrame()

    threats = run_async(get_threats(
        db=session,
        pagination=PaginationParams(page=page, size=size),
        severity=severity,
        status=status,
        category=category,
        search_query=search_query,
        from_date=from_date,
        to_date=to_date,
    ))

    if not threats:
        return pd.DataFrame()

    # Convert to DataFrame
    data = []
    for threat in threats:
        data.append({
            "id": threat.id,
            "title": threat.title,
            "description": threat.description,
            "severity": threat.severity.value if threat.severity else None,
            "status": threat.status.value if threat.status else None,
            "category": threat.category.value if threat.category else None,
            "source_url": threat.source_url,
            "source_name": threat.source_name,
            "source_type": threat.source_type,
            "discovered_at": threat.discovered_at,
            "affected_entity": threat.affected_entity,
            "affected_entity_type": threat.affected_entity_type,
            "confidence_score": threat.confidence_score,
            "risk_score": threat.risk_score,
        })

    return pd.DataFrame(data)


def add_threat(
    title: str,
    description: str,
    severity: ThreatSeverity,
    category: ThreatCategory,
    status: ThreatStatus = ThreatStatus.NEW,
    source_url: Optional[str] = None,
    source_name: Optional[str] = None,
    source_type: Optional[str] = None,
    affected_entity: Optional[str] = None,
    affected_entity_type: Optional[str] = None,
    confidence_score: float = 0.0,
    risk_score: float = 0.0,
) -> Optional[Threat]:
    """
    Add a new threat.

    Args:
        title: Threat title
        description: Threat description
        severity: Threat severity
        category: Threat category
        status: Threat status
        source_url: URL of the source
        source_name: Name of the source
        source_type: Type of source
        affected_entity: Name of affected entity
        affected_entity_type: Type of affected entity
        confidence_score: Confidence score (0-1)
        risk_score: Risk score (0-1)

    Returns:
        Optional[Threat]: Created threat or None
    """
    session = get_db_session()
    if not session:
        return None

    return run_async(create_threat(
        db=session,
        title=title,
        description=description,
        severity=severity,
        category=category,
        status=status,
        source_url=source_url,
        source_name=source_name,
        source_type=source_type,
        affected_entity=affected_entity,
        affected_entity_type=affected_entity_type,
        confidence_score=confidence_score,
        risk_score=risk_score,
    ))


def add_indicator(
    threat_id: int,
    value: str,
    indicator_type: IndicatorType,
    description: Optional[str] = None,
    is_verified: bool = False,
    context: Optional[str] = None,
    source: Optional[str] = None,
) -> Optional[Indicator]:
    """
    Add an indicator to a threat.

    Args:
        threat_id: Threat ID
        value: Indicator value
        indicator_type: Indicator type
        description: Indicator description
        is_verified: Whether the indicator is verified
        context: Context of the indicator
        source: Source of the indicator

    Returns:
        Optional[Indicator]: Created indicator or None
    """
    session = get_db_session()
    if not session:
        return None

    return run_async(add_indicator_to_threat(
        db=session,
        threat_id=threat_id,
        value=value,
        indicator_type=indicator_type,
        description=description,
        is_verified=is_verified,
        context=context,
        source=source,
    ))


def get_threat_stats(
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> Dict[str, Any]:
    """
    Get threat statistics.

    Args:
        from_date: Filter by discovered_at >= from_date
        to_date: Filter by discovered_at <= to_date

    Returns:
        Dict[str, Any]: Threat statistics
    """
    session = get_db_session()
    if not session:
        return {}

    return run_async(get_threat_statistics(
        db=session,
        from_date=from_date,
        to_date=to_date,
    ))
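

# Example (a sketch; assumes `import streamlit as st` in the calling page):
# dashboard metrics over the last 30 days. The keys in the returned dict depend
# on get_threat_statistics() in the service layer; "total" is assumed here
# purely for illustration.
#
#     start, end = get_time_range_dates("Last 30 Days")
#     stats = get_threat_stats(from_date=start, to_date=end)
#     st.metric("Total threats", stats.get("total", 0))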


# Reports functions

def get_reports_df(
    page: int = 1,
    size: int = 10,
    report_type: Optional[List[ReportType]] = None,
    status: Optional[List[ReportStatus]] = None,
    severity: Optional[List[ThreatSeverity]] = None,
    search_query: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
) -> pd.DataFrame:
    """
    Get reports as a DataFrame.

    Args:
        page: Page number
        size: Page size
        report_type: Filter by report type
        status: Filter by status
        severity: Filter by severity
        search_query: Search in title and summary
        from_date: Filter by created_at >= from_date
        to_date: Filter by created_at <= to_date

    Returns:
        pd.DataFrame: DataFrame with reports
    """
    session = get_db_session()
    if not session:
        return pd.DataFrame()

    reports = run_async(get_reports(
        db=session,
        pagination=PaginationParams(page=page, size=size),
        report_type=report_type,
        status=status,
        severity=severity,
        search_query=search_query,
        from_date=from_date,
        to_date=to_date,
    ))

    if not reports:
        return pd.DataFrame()

    # Convert to DataFrame
    data = []
    for report in reports:
        data.append({
            "id": report.id,
            "report_id": report.report_id,
            "title": report.title,
            "summary": report.summary,
            "report_type": report.report_type.value if report.report_type else None,
            "status": report.status.value if report.status else None,
            "severity": report.severity.value if report.severity else None,
            "publish_date": report.publish_date,
            "created_at": report.created_at,
            "time_period_start": report.time_period_start,
            "time_period_end": report.time_period_end,
            "author_id": report.author_id,
        })

    return pd.DataFrame(data)


def add_report(
    title: str,
    summary: str,
    content: str,
    report_type: ReportType,
    report_id: str,
    status: ReportStatus = ReportStatus.DRAFT,
    severity: Optional[ThreatSeverity] = None,
    publish_date: Optional[datetime] = None,
    time_period_start: Optional[datetime] = None,
    time_period_end: Optional[datetime] = None,
    keywords: Optional[List[str]] = None,
    author_id: Optional[int] = None,
) -> Optional[Report]:
    """
    Add a new report.

    Args:
        title: Report title
        summary: Report summary
        content: Report content
        report_type: Type of report
        report_id: Custom ID for the report
        status: Report status
        severity: Report severity
        publish_date: Publication date
        time_period_start: Start of time period covered
        time_period_end: End of time period covered
        keywords: List of keywords related to the report
        author_id: ID of the report author

    Returns:
        Optional[Report]: Created report or None
    """
    session = get_db_session()
    if not session:
        return None

    return run_async(create_report(
        db=session,
        title=title,
        summary=summary,
        content=content,
        report_type=report_type,
        report_id=report_id,
        status=status,
        severity=severity,
        publish_date=publish_date,
        time_period_start=time_period_start,
        time_period_end=time_period_end,
        keywords=keywords,
        author_id=author_id,
    ))


# Helper functions

def get_time_range_dates(time_range: str) -> Tuple[datetime, datetime]:
    """
    Get start and end dates for a time range.

    Args:
        time_range: Time range string (e.g., "Last 7 Days")

    Returns:
        Tuple[datetime, datetime]: (start_date, end_date)
    """
    end_date = datetime.utcnow()

    if time_range == "Last 24 Hours":
        start_date = end_date - timedelta(days=1)
    elif time_range == "Last 7 Days":
        start_date = end_date - timedelta(days=7)
    elif time_range == "Last 30 Days":
        start_date = end_date - timedelta(days=30)
    elif time_range == "Last Quarter":
        start_date = end_date - timedelta(days=90)
    else:  # Default to last 30 days
        start_date = end_date - timedelta(days=30)

    return start_date, end_date
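

# Example: convert a UI time-range label into the from_date/to_date filters used above.
#
#     start, end = get_time_range_dates("Last 7 Days")
#     threats_df = get_threats_df(from_date=start, to_date=end)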


# Initialize database connection
def init_db_connection():
    """Initialize the database connection and check whether the tables exist."""
    session = get_db_session()
    if not session:
        return False

    # Check if the tables exist using a raw SQL query wrapped in SQLAlchemy text()
    from sqlalchemy import text
    try:
        query = text("SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'users')")
        result = run_async(session.execute(query))
        return bool(result.scalar())
    except Exception as e:
        # Tables might not exist yet
        print(f"Error checking database: {e}")
        return False
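

if __name__ == "__main__":
    # Minimal smoke test (a sketch): verify connectivity and print a few summaries.
    # It assumes DATABASE_URL points at a reachable PostgreSQL instance whose
    # application tables have already been created.
    if init_db_connection():
        print("Database connection OK")
        print("Alert counts by severity:", get_alert_severity_counts())
        print("Threats (first page):")
        print(get_threats_df(page=1, size=5))
    else:
        print("Database tables not found or connection failed")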