# File: nimbusflow/backend/cli/base.py

"""
Base CLI class and utilities for NimbusFlow CLI.
"""
import shutil
from datetime import datetime
from pathlib import Path
from typing import Optional
from backend.db.connection import DatabaseConnection
from backend.repositories import (
MemberRepository,
ClassificationRepository,
ServiceRepository,
ServiceAvailabilityRepository,
ScheduleRepository,
ServiceTypeRepository
)
from backend.services.scheduling_service import SchedulingService
# Use the shared palette from the interactive module so every CLI screen is
# styled the same way; when that module cannot be imported, fall back to a
# minimal local palette that defines only the escape codes used in this file.
try:
    from .interactive import Colors
except ImportError:
    class Colors:
        """Fallback ANSI escape-code palette (used when ``.interactive`` is missing)."""

        RESET = '\033[0m'
        SUCCESS = '\033[1m\033[92m'  # bold + bright green
        WARNING = '\033[1m\033[93m'  # bold + bright yellow
        ERROR = '\033[1m\033[91m'    # bold + bright red
        CYAN = '\033[96m'            # bright cyan
        DIM = '\033[2m'              # faint
class CLIError(Exception):
    """Error raised by the NimbusFlow CLI layer (e.g. when the schema file is missing)."""
class NimbusFlowCLI:
    """Main CLI application class with database versioning.

    On start-up the CLI picks the most recent versioned database in the
    ``db/sqlite`` directory two levels above this file (falling back to, and if
    needed creating, the base database), optionally copies it to a new
    version, and opens a ``DatabaseConnection`` plus the repositories and
    scheduling service built on it.

    Instances can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.
    """

    # Glob for versioned database files, e.g. "database_v12_20250828_143022.db".
    _VERSION_GLOB = "database_v*_*.db"

    def __init__(self, db_path: str = "database.db", create_version: bool = True):
        """Initialize CLI with database connection, always using most recent version.

        Args:
            db_path: File name of the base database inside the sqlite
                directory. It is only used when no versioned database exists.
            create_version: When True, copy the most recent database to a new
                versioned file and open that copy instead of the original.
        """
        self.db_dir = Path(__file__).parent.parent / "db" / "sqlite"
        self.base_db_path = self.db_dir / db_path
        # Always start from the most recent database version.
        self.db_path = self._get_most_recent_database()
        if create_version:
            # Work on a fresh copy so the previous version stays untouched.
            self.db_path = self._create_versioned_database()
        self.db = DatabaseConnection(self.db_path)
        self._init_repositories()

    def __enter__(self) -> "NimbusFlowCLI":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Returning None means exceptions raised inside the block propagate.
        self.close()

    @staticmethod
    def _parse_version_number(db_file: Path) -> Optional[int]:
        """Return the version number encoded in a versioned database file name.

        ``database_v123_20250828_143022.db`` yields ``123``; names that do not
        follow the ``<name>_v<N>_...`` layout yield ``None``.
        """
        parts = db_file.stem.split('_')
        if len(parts) < 2 or not parts[1].startswith('v'):
            return None
        try:
            return int(parts[1][1:])  # strip the 'v' prefix
        except ValueError:
            return None

    def _get_most_recent_database(self) -> Path:
        """Return the newest versioned database, or else the base database.

        If no versioned database exists and the base database file is missing,
        it is created from ``schema.sql`` first.
        """
        versions = self.list_database_versions()
        if versions:
            return versions[0]  # list_database_versions() sorts newest first
        if not self.base_db_path.exists():
            self._create_base_database()
        return self.base_db_path

    def _create_base_database(self) -> None:
        """Create the base database by running ``schema.sql`` (in the backend dir).

        Raises:
            CLIError: If the schema file does not exist.
        """
        self.db_dir.mkdir(parents=True, exist_ok=True)
        schema_path = Path(__file__).parent.parent / "schema.sql"
        if not schema_path.exists():
            raise CLIError(f"Schema file not found: {schema_path}")
        # Explicit encoding: the platform default (e.g. cp1252) may mis-decode.
        schema_sql = schema_path.read_text(encoding="utf-8")

        existed_before = self.base_db_path.exists()
        try:
            with DatabaseConnection(self.base_db_path) as db:
                db.executescript(schema_sql)
        except Exception:
            # A half-applied schema would otherwise be reused as a "valid" base
            # database on every later run; remove it, but only if we created it.
            if not existed_before:
                self.base_db_path.unlink(missing_ok=True)
            raise
        print(f"{Colors.SUCCESS}Created new database:{Colors.RESET} {Colors.CYAN}{self.base_db_path.name}{Colors.RESET}")
        print(f"{Colors.DIM}Location: {self.base_db_path}{Colors.RESET}")

    def _create_versioned_database(self) -> Path:
        """Copy the current database (``self.db_path``) to a new versioned file.

        Returns:
            Path of the new ``database_v<N>_<YYYYmmdd_HHMMSS>.db`` file.
        """
        source_db = self.db_path
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        version = self._get_next_version_number()
        versioned_name = f"database_v{version}_{timestamp}.db"
        versioned_path = self.db_dir / versioned_name
        # copy2 preserves the source's mtime, so the copy can tie with its
        # source; list_database_versions() breaks such ties by version number.
        # NOTE(review): only the main file is copied — if DatabaseConnection
        # uses SQLite WAL mode, un-checkpointed "-wal" content is lost; confirm.
        shutil.copy2(source_db, versioned_path)
        print(f"{Colors.SUCCESS}Created versioned database:{Colors.RESET} {Colors.CYAN}{versioned_name}{Colors.RESET}")
        print(f"{Colors.DIM}Based on: {source_db.name}{Colors.RESET}")
        return versioned_path

    def _get_next_version_number(self) -> int:
        """Return one more than the highest version number on disk (1 if none)."""
        versions = [
            number
            for number in map(self._parse_version_number, self.db_dir.glob(self._VERSION_GLOB))
            if number is not None
        ]
        return max(versions, default=0) + 1

    def list_database_versions(self) -> list[Path]:
        """List versioned databases, newest first.

        Files are ordered by modification time. Ties, which are common because
        ``shutil.copy2`` keeps the source's mtime on a new version, are broken
        by the version number in the file name (highest first); names without
        a parsable version sort last within a tie.
        """
        def newest_first_key(db_file: Path) -> tuple[float, int]:
            version = self._parse_version_number(db_file)
            return db_file.stat().st_mtime, (version if version is not None else -1)

        return sorted(self.db_dir.glob(self._VERSION_GLOB), key=newest_first_key, reverse=True)

    def cleanup_old_versions(self, keep_latest: int = 5) -> int:
        """Delete old database versions, keeping only the newest ``keep_latest``.

        The database currently in use (``self.db_path``) is never deleted, even
        when it falls outside the newest ``keep_latest`` versions.

        Args:
            keep_latest: Number of newest versions to keep; must be >= 0.

        Returns:
            Number of files actually deleted.

        Raises:
            ValueError: If ``keep_latest`` is negative.
        """
        if keep_latest < 0:
            raise ValueError(f"keep_latest must be >= 0, got {keep_latest}")
        active_db = Path(self.db_path).resolve()
        deleted_count = 0
        for db_path in self.list_database_versions()[keep_latest:]:
            if db_path.resolve() == active_db:
                # Deleting the open database would pull the file out from
                # under this session's connection.
                continue
            try:
                db_path.unlink()
            except OSError as e:
                print(f"{Colors.WARNING}⚠️ Could not delete {db_path.name}: {e}{Colors.RESET}")
            else:
                deleted_count += 1
                print(f"{Colors.DIM}Deleted old version: {db_path.name}{Colors.RESET}")
        return deleted_count

    def _init_repositories(self):
        """Create all repositories and the scheduling service on ``self.db``."""
        self.member_repo = MemberRepository(self.db)
        self.classification_repo = ClassificationRepository(self.db)
        self.service_repo = ServiceRepository(self.db)
        self.availability_repo = ServiceAvailabilityRepository(self.db)
        self.schedule_repo = ScheduleRepository(self.db)
        self.service_type_repo = ServiceTypeRepository(self.db)
        # The scheduling service is wired from the repositories above.
        self.scheduling_service = SchedulingService(
            classification_repo=self.classification_repo,
            member_repo=self.member_repo,
            service_repo=self.service_repo,
            availability_repo=self.availability_repo,
            schedule_repo=self.schedule_repo,
        )

    def close(self):
        """Close the database connection, if one was opened."""
        # __init__ may have failed before self.db was assigned.
        if hasattr(self, 'db'):
            self.db.close()