180 lines
6.9 KiB
Python
180 lines
6.9 KiB
Python
"""
|
|
Base CLI class and utilities for NimbusFlow CLI.
|
|
"""
|
|
|
|
import shutil
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
from backend.db.connection import DatabaseConnection
|
|
from backend.repositories import (
|
|
MemberRepository,
|
|
ClassificationRepository,
|
|
ServiceRepository,
|
|
ServiceAvailabilityRepository,
|
|
ScheduleRepository,
|
|
ServiceTypeRepository
|
|
)
|
|
from backend.services.scheduling_service import SchedulingService
|
|
|
|
# Prefer the shared palette from the interactive module so every CLI message is
# styled the same way; define a minimal local palette if that import fails.
try:
    from .interactive import Colors
except ImportError:
    class Colors:
        """Fallback ANSI escape palette used when ``.interactive`` cannot be imported."""

        # Neutral styles
        RESET = '\033[0m'
        DIM = '\033[2m'
        CYAN = '\033[96m'

        # Bold status colours
        SUCCESS = '\033[1m\033[92m'
        WARNING = '\033[1m\033[93m'
        ERROR = '\033[1m\033[91m'
|
|
|
|
|
|
class CLIError(Exception):
    """Error raised for CLI-level failures, e.g. a missing schema file."""
|
|
|
|
|
|
class NimbusFlowCLI:
    """Main CLI application class with database versioning.

    On construction the CLI locates the newest versioned database file in the
    ``db/sqlite`` directory two levels above this file (creating the base
    database from ``schema.sql`` when nothing exists yet), optionally snapshots
    it into a new version, opens a connection to the result and wires up the
    repositories and the scheduling service on top of that connection.
    """

    # Versioned snapshots are named "database_v<N>_<YYYYmmdd>_<HHMMSS>.db".
    _VERSION_GLOB = "database_v*_*.db"

    def __init__(self, db_path: str = "database.db", create_version: bool = True):
        """Open the most recent database version, optionally snapshotting it first.

        Args:
            db_path: File name of the base (unversioned) database inside
                ``db_dir``. It is only used when no versioned database exists.
            create_version: When true, copy the most recent database into a new
                ``database_v<N>_<timestamp>.db`` file and work on that copy.
        """
        self.db_dir = Path(__file__).parent.parent / "db" / "sqlite"
        self.base_db_path = self.db_dir / db_path

        # Always start from the most recent database version.
        self.db_path = self._get_most_recent_database()

        if create_version:
            # Work on a fresh snapshot of the most recent version.
            self.db_path = self._create_versioned_database()

        self.db = DatabaseConnection(self.db_path)
        self._init_repositories()

    def __enter__(self) -> "NimbusFlowCLI":
        """Allow ``with NimbusFlowCLI(...) as cli:`` usage."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the database connection; exceptions are not suppressed."""
        self.close()

    def _get_most_recent_database(self) -> Path:
        """Return the newest versioned database, falling back to the base database.

        If no versioned file exists and the base database is not on disk yet,
        it is created from the schema first.
        """
        versions = self.list_database_versions()
        if versions:
            return versions[0]  # ordered newest version first

        if not self.base_db_path.exists():
            self._create_base_database()
        return self.base_db_path

    def _create_base_database(self) -> None:
        """Create the base database and initialise it from ``schema.sql``.

        Raises:
            CLIError: If the schema file does not exist.
        """
        self.db_dir.mkdir(parents=True, exist_ok=True)

        schema_path = Path(__file__).parent.parent / "schema.sql"
        if not schema_path.exists():
            raise CLIError(f"Schema file not found: {schema_path}")

        # Decode explicitly so the result does not depend on the platform locale.
        schema_sql = schema_path.read_text(encoding="utf-8")

        try:
            with DatabaseConnection(self.base_db_path) as db:
                db.executescript(schema_sql)
        except Exception:
            # Do not leave a half-initialised file behind: the caller only checks
            # exists(), so a partial database would never get its schema applied.
            try:
                self.base_db_path.unlink(missing_ok=True)
            except OSError:
                pass  # best effort; the original error is re-raised below
            raise

        print(f"{Colors.SUCCESS}Created new database:{Colors.RESET} {Colors.CYAN}{self.base_db_path.name}{Colors.RESET}")
        print(f"{Colors.DIM}Location: {self.base_db_path}{Colors.RESET}")

    def _create_versioned_database(self) -> Path:
        """Copy ``self.db_path`` into a new versioned file and return its path."""
        source_db = self.db_path  # the most recent database, set by __init__

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        version = self._get_next_version_number()

        versioned_name = f"database_v{version}_{timestamp}.db"
        versioned_path = self.db_dir / versioned_name

        # copy2 also copies file metadata, including the modification time, so
        # the new snapshot can share its source's mtime. Version ordering thus
        # relies on the number in the filename (see list_database_versions).
        # NOTE(review): a plain file copy does not include a separate
        # write-ahead-log file if the database runs in WAL mode — confirm.
        shutil.copy2(source_db, versioned_path)

        print(f"{Colors.SUCCESS}Created versioned database:{Colors.RESET} {Colors.CYAN}{versioned_name}{Colors.RESET}")
        print(f"{Colors.DIM}Based on: {source_db.name}{Colors.RESET}")
        return versioned_path

    @staticmethod
    def _parse_version_number(db_file: Path) -> Optional[int]:
        """Return N from a name like ``database_vN_<timestamp>.db``, or None if absent."""
        parts = db_file.stem.split('_')
        if len(parts) < 2 or not parts[1].startswith('v'):
            return None
        try:
            return int(parts[1][1:])  # drop the 'v' prefix
        except ValueError:
            return None

    def _get_next_version_number(self) -> int:
        """Return one more than the highest existing version number, or 1 if none."""
        numbers = [
            number
            for number in map(self._parse_version_number, self.db_dir.glob(self._VERSION_GLOB))
            if number is not None
        ]
        return max(numbers, default=0) + 1

    def list_database_versions(self) -> list[Path]:
        """List versioned database files, newest version first.

        Files are ordered by the version number encoded in their name. The
        modification time only breaks ties. Files without a parsable number
        sort after all numbered versions. mtime alone is unreliable here:
        ``shutil.copy2`` copies it onto new snapshots, and checkouts or restores
        rewrite it.
        """
        def newest_first_key(path: Path) -> tuple:
            number = self._parse_version_number(path)
            mtime = path.stat().st_mtime
            if number is None:
                return (0, 0, mtime)
            return (1, number, mtime)

        return sorted(self.db_dir.glob(self._VERSION_GLOB), key=newest_first_key, reverse=True)

    def cleanup_old_versions(self, keep_latest: int = 5) -> int:
        """Delete old database versions, keeping only the latest ``keep_latest``.

        The database this CLI currently has open is never deleted, even if it
        falls outside the kept range.

        Args:
            keep_latest: Number of newest versions to keep; must be >= 0.

        Returns:
            The number of files actually deleted.

        Raises:
            ValueError: If ``keep_latest`` is negative.
        """
        if keep_latest < 0:
            raise ValueError(f"keep_latest must be >= 0, got {keep_latest}")

        versions = self.list_database_versions()
        if len(versions) <= keep_latest:
            return 0

        active = getattr(self, "db_path", None)
        active_resolved = Path(active).resolve() if active is not None else None

        deleted_count = 0
        for db_path in versions[keep_latest:]:
            if active_resolved is not None and db_path.resolve() == active_resolved:
                continue  # never unlink the database the open connection uses
            try:
                db_path.unlink()
                deleted_count += 1
                print(f"{Colors.DIM}Deleted old version: {db_path.name}{Colors.RESET}")
            except OSError as e:
                print(f"{Colors.WARNING}⚠️ Could not delete {db_path.name}: {e}{Colors.RESET}")

        return deleted_count

    def _init_repositories(self):
        """Create the repository instances and the scheduling service on ``self.db``."""
        self.member_repo = MemberRepository(self.db)
        self.classification_repo = ClassificationRepository(self.db)
        self.service_repo = ServiceRepository(self.db)
        self.availability_repo = ServiceAvailabilityRepository(self.db)
        self.schedule_repo = ScheduleRepository(self.db)
        self.service_type_repo = ServiceTypeRepository(self.db)

        # The scheduling service composes the repositories created above.
        self.scheduling_service = SchedulingService(
            classification_repo=self.classification_repo,
            member_repo=self.member_repo,
            service_repo=self.service_repo,
            availability_repo=self.availability_repo,
            schedule_repo=self.schedule_repo,
        )

    def close(self):
        """Close the database connection if one was opened."""
        if hasattr(self, 'db'):
            self.db.close()