#!/usr/bin/env python3
"""
Optimized BLE Analysis with HDBSCAN Clustering
Memory-efficient version for large datasets
"""

import json
import os
import gc
import warnings
from collections import defaultdict, Counter
from multiprocessing import Pool, cpu_count

import numpy as np
import pandas as pd
from tqdm import tqdm

warnings.filterwarnings('ignore')

# Configuration
DATA_DIR = "/var/www/html/CrowdFlow/HYROX"
FILE_PREFIX = "HYROX"
EVENT_DATES = ["20250926"]
RESULTS_DIR = "./analysis_results"

# Hour range to process (inclusive); narrow this if the event doesn't run 24 hours
START_HOUR = 8
END_HOUR = 12

# Analysis parameters
MIN_RSSI_SAMPLES = 5       # Reduced for faster processing
MIN_CLUSTER_SIZE = 2
DWELL_TIME_THRESHOLD = 15  # minutes; max gap between sightings within one session

# Performance options
USE_MULTIPROCESSING = True
MEMORY_CLEANUP_INTERVAL = 5000  # Run gc.collect() every N devices (single-threaded path)

# Known manufacturer OUIs (used for Public classification and Classic BT detection)
KNOWN_MANUFACTURER_OUIS = {
    "000A95": "Bose",
    "04E676": "Sony",
    "006037": "JBL",
    "00166C": "Beats",
    "B8F6B1": "Apple (AirPods)",
    "AC9E17": "Apple (AirPods)",
    "0018DE": "Nordic Semiconductor",
    "AC233F": "Shenzhen Minew",
    "5A6D5A": "Estimote",
    "0025DF": "Kontakt.io",
}
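
# The OUI is the upper 3 bytes (first 6 hex characters) of a normalized
# address, so lookups below reduce to address[:6] membership tests, e.g.
# (made-up address) "B8F6B1A2C3D4"[:6] -> "B8F6B1" -> "Apple (AirPods)".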

class OptimizedBLEAnalyzer:
    def __init__(self):
        self.devices = {}
        self.fingerprints = {}
        self.clusters = {}
        self.infrastructure_devices = set()
        self.classic_bt_devices = set()
        self.file_loading_log = []
        
        # Create results directory
        os.makedirs(RESULTS_DIR, exist_ok=True)
    
    def normalize_address(self, address):
        """Normalize a MAC address to 12 uppercase hex characters (no separators)."""
        if isinstance(address, (int, float)):
            address = str(int(address))
        else:
            address = str(address)

        address = address.replace(':', '').upper()
        if not address:
            return ''  # keep empty input falsy so callers can skip the record
        if len(address) < 12:
            address = address.zfill(12)
        return address[:12]
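
    # Example (made-up address): normalize_address("b8:f6:b1:a2:c3:d4")
    # returns "B8F6B1A2C3D4"; shorter inputs are left-padded with zeros.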
    
    def classify_address(self, address):
        """Heuristically classify the MAC address type.

        Per the BLE Core Specification, the two MSBs of a *random* address
        define its sub-type (11 = static, 01 = resolvable private,
        00 = non-resolvable private; 10 is reserved). Public vs. random is
        actually signalled by the TxAdd bit of the advertising PDU, which
        is not available here, so this bit-based mapping is a heuristic.
        """
        if len(address) != 12:
            return "Invalid"

        try:
            # A known manufacturer OUI overrides the bit heuristic: several
            # real OUIs (e.g. Bose 000A95) start with 00xxxxxx and would
            # otherwise be mistaken for random private addresses.
            if address[:6] in KNOWN_MANUFACTURER_OUIS:
                return "Public"

            # Get the most significant byte
            msb = int(address[:2], 16)
            type_bits = (msb >> 6) & 0x03

            if type_bits == 0b11:    # 11xxxxxx
                return "Random Static"
            elif type_bits == 0b01:  # 01xxxxxx
                return "Random Private Resolvable"
            elif type_bits == 0b00:  # 00xxxxxx
                return "Random Private Non-Resolvable"
            else:                    # 10xxxxxx: reserved for random addresses
                if self.is_valid_oui_structure(address):
                    return "Public (Unregistered OUI)"
                else:
                    return "Ambiguous"
        except ValueError:
            return "Invalid"
    
    def is_valid_oui_structure(self, address):
        """Check whether the upper 3 bytes look like a plausible OUI.

        Rejects all-zero / all-one OUIs and repeated-byte patterns, e.g.
        (made-up addresses) "AAAAAA123456" fails, "B8F6B1123456" passes.
        """
        oui = address[:6]
        if oui in ("000000", "FFFFFF"):
            return False
        if address[0:2] == address[2:4] == address[4:6]:
            return False
        return True
    
    def extract_features_optimized(self, device_data):
        """Streamlined feature extraction - only essential features"""
        features = {}
        
        # Basic RSSI stats (no histogram)
        rssi_vals = device_data.get('rssi_values', [])
        if len(rssi_vals) >= MIN_RSSI_SAMPLES:
            features['rssi_mean'] = np.mean(rssi_vals)
            features['rssi_std'] = np.std(rssi_vals)
        else:
            features['rssi_mean'] = -80
            features['rssi_std'] = 10
        
        # Simple timing features (only if enough data)
        if device_data.get('total_records', 0) > 10:
            timestamps = device_data.get('timestamps', [])
            if len(timestamps) > 1:
                intervals = np.diff(sorted(timestamps))
                if len(intervals) > 0:
                    features['avg_interval'] = np.mean(intervals)
                    features['interval_std'] = np.std(intervals) if len(intervals) > 1 else 0
        
        # Basic location metrics
        locations = device_data.get('locations', [])
        if locations:
            loc_counts = Counter(locations)
            features['location_diversity'] = len(loc_counts)
            features['primary_location_ratio'] = max(loc_counts.values()) / len(locations)
        
        # Essential behavioral features
        features['total_dwell_time'] = device_data.get('total_dwell_time', 0)
        features['session_count'] = device_data.get('session_count', 0)
        features['total_records'] = device_data.get('total_records', 0)
        features['first_seen_hour'] = device_data.get('first_seen_hour', 0)
        
        return features
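
    # A fingerprint is a flat dict, e.g. (illustrative values only):
    #   {'rssi_mean': -72.4, 'rssi_std': 6.1, 'avg_interval': 2.3,
    #    'interval_std': 0.8, 'location_diversity': 2,
    #    'primary_location_ratio': 0.85, 'total_dwell_time': 47.0,
    #    'session_count': 3, 'total_records': 120, 'first_seen_hour': 9}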
    
    def is_likely_infrastructure(self, device_fingerprint):
        """Detect likely infrastructure devices (fixed beacons, venue equipment)"""
        indicators = 0

        # Very stable RSSI suggests a stationary transmitter
        if device_fingerprint.get('rssi_std', float('inf')) < 3:
            indicators += 1
        # Present for more than 2 hours
        if device_fingerprint.get('total_dwell_time', 0) > 120:
            indicators += 1
        # ~1 s advertising cadence is typical for beacons
        avg_interval = device_fingerprint.get('avg_interval', 0)
        if 0.9 < avg_interval < 1.1:
            indicators += 1
        # Highly regular advertising interval
        if device_fingerprint.get('interval_std', float('inf')) < 0.1:
            indicators += 1
        # Never seen at more than one scanner location
        if device_fingerprint.get('location_diversity', float('inf')) == 1:
            indicators += 1

        return indicators >= 3
    
    def is_likely_classic_bluetooth(self, address, device_fingerprint):
        """Detect likely Classic Bluetooth devices"""
        indicators = 0
        
        if device_fingerprint.get('rssi_std', float('inf')) < 3:
            indicators += 1
        
        avg_interval = device_fingerprint.get('avg_interval', 0)
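        # 1.28 s is the base inquiry/page-scan interval in Classic Bluetooth;
        # the other values are common multiples/sub-multiples of it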
        classic_intervals = [1.28, 2.56, 0.64, 5.12]
        for interval in classic_intervals:
            if abs(avg_interval - interval) < 0.05:
                indicators += 2
                break
        
        if device_fingerprint.get('total_dwell_time', 0) > 60:
            indicators += 1
        if device_fingerprint.get('total_records', 0) > 1000:
            indicators += 1
        if device_fingerprint.get('session_count', 0) == 1:
            indicators += 1
        
        # Check Classic BT OUI
        oui = address[:6]
        if oui in ["000A95", "04E676", "006037", "00166C", "B8F6B1", "AC9E17"]:
            indicators += 2
        
        return indicators >= 4
    
    def aggregate_raw_records(self, raw_records):
        """Aggregate raw records into device sessions"""
        raw_records.sort(key=lambda x: x['timestamp'])
        
        sessions = []
        current_session = None
        
        for record in raw_records:
            if current_session is None:
                current_session = {
                    'firstSeenTime': record['timestamp'],
                    'lastSeenTime': record['timestamp'],
                    'rssi_values': [record['rssi']],
                    'locations': [record['Location']],
                    'timestamps': [record['timestamp']]
                }
            else:
                last_time = pd.to_datetime(current_session['lastSeenTime'])
                current_time = pd.to_datetime(record['timestamp'])
                time_gap = (current_time - last_time).total_seconds() / 60
                
                if time_gap <= DWELL_TIME_THRESHOLD:
                    current_session['lastSeenTime'] = record['timestamp']
                    current_session['rssi_values'].append(record['rssi'])
                    current_session['locations'].append(record['Location'])
                    current_session['timestamps'].append(record['timestamp'])
                else:
                    sessions.append(current_session)
                    current_session = {
                        'firstSeenTime': record['timestamp'],
                        'lastSeenTime': record['timestamp'],
                        'rssi_values': [record['rssi']],
                        'locations': [record['Location']],
                        'timestamps': [record['timestamp']]
                    }
        
        if current_session:
            sessions.append(current_session)
        
        # Calculate derived metrics
        for session in sessions:
            first_time = pd.to_datetime(session['firstSeenTime'])
            last_time = pd.to_datetime(session['lastSeenTime'])
            session['dwellTime'] = (last_time - first_time).total_seconds() / 60
            session['recordCount'] = len(session['rssi_values'])
            session['avgRssi'] = np.mean(session['rssi_values'])
            session['locationsVisited'] = list(set(session['locations']))
        
        return sessions
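
    # Example: with DWELL_TIME_THRESHOLD = 15, sightings at 10:00, 10:05 and
    # 10:30 yield two sessions (10:00-10:05 and 10:30 alone), because the
    # 25-minute gap exceeds the threshold.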
    
    def process_devices_batch(self, device_items):
        """Process a batch of devices for fingerprinting"""
        batch_fingerprints = {}
        for addr, device_data in device_items:
            batch_fingerprints[addr] = self.extract_features_optimized(device_data)
        return batch_fingerprints
    
    def create_fingerprints_parallel(self, date_devices):
        """Create fingerprints using multiprocessing"""
        print(f"  Creating fingerprints for {len(date_devices)} devices...")
        
        device_items = list(date_devices.items())
        
        if USE_MULTIPROCESSING and len(device_items) > 1000:
            # Use multiprocessing for large datasets
            n_cores = min(cpu_count(), 8)  # Limit cores to prevent memory issues
            chunk_size = max(500, len(device_items) // (n_cores * 4))
            chunks = [device_items[i:i+chunk_size] for i in range(0, len(device_items), chunk_size)]
            
            print(f"  Using {n_cores} cores, {len(chunks)} chunks...")
            
            with Pool(processes=n_cores) as pool:
                results = list(tqdm(
                    pool.imap(self.process_devices_batch, chunks),
                    total=len(chunks),
                    desc="  Processing chunks"
                ))
            
            # Combine results
            for batch_fps in results:
                self.fingerprints.update(batch_fps)
        else:
            # Single-threaded for smaller datasets
            for i, (addr, device_data) in enumerate(tqdm(device_items, desc="  Fingerprinting")):
                self.fingerprints[addr] = self.extract_features_optimized(device_data)
                
                # Periodic memory cleanup (skip i == 0)
                if i and i % MEMORY_CLEANUP_INTERVAL == 0:
                    gc.collect()
    
    def simple_clustering(self, devices):
        """Simplified clustering using DBSCAN"""
        try:
            from sklearn.cluster import DBSCAN
            from sklearn.preprocessing import StandardScaler
        except ImportError:
            print("  scikit-learn not available. Skipping clustering.")
            return {}
        
        # Pre-filter devices
        clusterable = {}
        for addr, device in devices.items():
            if addr in self.infrastructure_devices or addr in self.classic_bt_devices:
                continue
            
            fp = self.fingerprints.get(addr, {})
            if (5 <= fp.get('total_dwell_time', 0) <= 300 and
                fp.get('total_records', 0) >= 5):
                clusterable[addr] = device
        
        if len(clusterable) < MIN_CLUSTER_SIZE:
            return {}
        
        print(f"  Clustering {len(clusterable)} devices...")
        
        # Build feature matrix
        addresses = list(clusterable.keys())
        features = []
        
        for addr in addresses:
            fp = self.fingerprints.get(addr, {})
            features.append([
                fp.get('rssi_mean', -80),
                fp.get('rssi_std', 10),
                np.log1p(fp.get('total_dwell_time', 0)),
                fp.get('location_diversity', 1),
                fp.get('first_seen_hour', 12)
            ])
        
        X = np.array(features)
        X_scaled = StandardScaler().fit_transform(X)
        
        # DBSCAN with fixed parameters: eps=0.5 in z-scored feature space
        clustering = DBSCAN(eps=0.5, min_samples=MIN_CLUSTER_SIZE, n_jobs=1)
        labels = clustering.fit_predict(X_scaled)

        # Convert labels to {cluster_label: [addresses]}; label -1 is DBSCAN noise
        clusters = defaultdict(list)
        for i, label in enumerate(labels):
            if label != -1:
                clusters[label].append(addresses[i])
        
        # Filter out single-device clusters
        valid_clusters = {k: v for k, v in clusters.items() if len(v) >= MIN_CLUSTER_SIZE}
        
        print(f"  Found {len(valid_clusters)} clusters")
        return valid_clusters
    
    def load_and_process_data(self):
        """Load data efficiently with batch processing"""
        all_sessions = []
        
        for date in EVENT_DATES:
            print(f"\nProcessing date: {date}")
            
            date_devices = {}
            hours_found = []
            total_records_processed = 0
            
            # Process each hour separately to manage memory
            for hour in range(START_HOUR, END_HOUR + 1):
                hour_str = f"{hour:02d}"
                filename = os.path.join(DATA_DIR, f"{FILE_PREFIX}_combined_flat_{date}_{hour_str}.json")
                
                if os.path.exists(filename):
                    try:
                        print(f"  Processing hour {hour_str}...")
                        with open(filename, 'r') as f:
                            hour_data = json.load(f)
                        
                        print(f"    Loaded {len(hour_data)} records")
                        hours_found.append(hour_str)
                        
                        # Group by device
                        hour_by_device = defaultdict(list)
                        for record in hour_data:
                            address = self.normalize_address(record.get('address', ''))
                            if address:
                                try:
                                    record['timestamp'] = pd.to_datetime(record['timestamp'])
                                    hour_by_device[address].append(record)
                                except (KeyError, ValueError, TypeError):
                                    continue  # skip records with missing/unparseable timestamps
                        
                        # Process devices in this hour
                        for address, raw_records in hour_by_device.items():
                            sessions = self.aggregate_raw_records(raw_records)
                            
                            if sessions:
                                if address not in date_devices:
                                    date_devices[address] = {
                                        'sessions': [],
                                        'rssi_values': [],
                                        'timestamps': [],
                                        'locations': [],
                                        'address': address,
                                        'date': date
                                    }
                                
                                # Append sessions
                                date_devices[address]['sessions'].extend(sessions)
                                
                                # Aggregate values efficiently
                                for session in sessions:
                                    date_devices[address]['rssi_values'].extend(session['rssi_values'])
                                    date_devices[address]['locations'].extend(session['locations'])
                                    date_devices[address]['timestamps'].extend(
                                        [pd.to_datetime(ts).timestamp() for ts in session['timestamps']]
                                    )
                        
                        total_records_processed += len(hour_data)
                        
                        # Clear hour data to free memory
                        hour_data = None
                        hour_by_device = None
                        gc.collect()
                        
                    except Exception as e:
                        print(f"  Error loading hour {hour_str}: {e}")
                        self.file_loading_log.append({
                            'date': date,
                            'hour': hour_str,
                            'filename': filename,
                            'error': str(e)
                        })
            
            if not date_devices:
                print(f"  No data found for {date}")
                continue
            
            print(f"  Total records processed: {total_records_processed:,}")
            print(f"  Total devices found: {len(date_devices):,}")
            
            # Calculate final metrics for each device
            for addr, device in date_devices.items():
                device['total_dwell_time'] = sum(s['dwellTime'] for s in device['sessions'])
                device['session_count'] = len(device['sessions'])
                device['total_records'] = sum(s['recordCount'] for s in device['sessions'])
                
                if device['sessions']:
                    # Hours are processed in ascending order, so sessions[0] is the earliest
                    first_ts = pd.to_datetime(device['sessions'][0]['firstSeenTime'])
                    device['first_seen_hour'] = first_ts.hour
                
                # Add to all_sessions for final analysis
                for session in device['sessions']:
                    all_sessions.append({
                        'date': date,
                        'address': addr,
                        'address_type': self.classify_address(addr),
                        'firstSeenTime': session['firstSeenTime'],
                        'lastSeenTime': session['lastSeenTime'],
                        'dwellTime': session['dwellTime'],
                        'recordCount': session['recordCount'],
                        'avgRssi': session['avgRssi'],
                        'locationsVisited': session['locationsVisited']
                    })
            
            # Create fingerprints
            self.create_fingerprints_parallel(date_devices)
            
            # Identify device types
            print("  Identifying device types...")
            for addr, fp in self.fingerprints.items():
                if self.is_likely_infrastructure(fp):
                    self.infrastructure_devices.add(addr)
                elif self.is_likely_classic_bluetooth(addr, fp):
                    self.classic_bt_devices.add(addr)
            
            print(f"  Infrastructure devices: {len(self.infrastructure_devices)}")
            print(f"  Classic Bluetooth devices: {len(self.classic_bt_devices)}")
            
            # Clustering
            try:
                clusters = self.simple_clustering(date_devices)
                self.clusters[date] = clusters
            except Exception as e:
                print(f"  Clustering error: {e}")
                self.clusters[date] = {}
            
            # Store devices
            self.devices[date] = date_devices
            
            # Clear to save memory
            date_devices = None
            gc.collect()
        
        # Create DataFrame
        self.df = pd.DataFrame(all_sessions)
        if not self.df.empty:
            self.df['firstSeenTime'] = pd.to_datetime(self.df['firstSeenTime'])
            self.df['lastSeenTime'] = pd.to_datetime(self.df['lastSeenTime'])
    
    def analyze_attendance(self):
        """Focused attendance analysis"""
        print("\n=== ATTENDANCE ANALYSIS ===")
        
        if self.df.empty:
            print("No data to analyze!")
            return
        
        # 1. Raw counts
        daily_unique = self.df.groupby('date')['address'].nunique()
        print(f"\n1. RAW UNIQUE DEVICES:")
        for date, count in daily_unique.items():
            print(f"  {date}: {count:,}")
        
        # 2. Filter out infrastructure
        visitor_df = self.df[~self.df['address'].isin(self.infrastructure_devices)]
        
        print(f"\n2. DEVICE CATEGORIZATION:")
        print(f"  Infrastructure devices: {len(self.infrastructure_devices)}")
        print(f"  Classic Bluetooth devices: {len(self.classic_bt_devices)}")
        print(f"  Total visitor devices: {visitor_df['address'].nunique()}")
        
        # 3. Separate Classic BT from BLE
        classic_bt_df = visitor_df[visitor_df['address'].isin(self.classic_bt_devices)]
        ble_df = visitor_df[~visitor_df['address'].isin(self.classic_bt_devices)]
        
        # 4. Analyze BLE devices by dwell time
        ble_device_dwell = ble_df.groupby(['date', 'address', 'address_type'])['dwellTime'].sum().reset_index()
        ble_device_dwell.columns = ['date', 'address', 'address_type', 'total_dwell_time']
        
        print(f"\n3. BLE DEVICE ANALYSIS (by dwell time):")
        thresholds = [0, 10, 15, 30, 60]
        
        for threshold in thresholds:
            filtered = ble_device_dwell[ble_device_dwell['total_dwell_time'] >= threshold]
            print(f"\n  ≥{threshold} min dwell time:")
            
            by_date = filtered.groupby('date')['address'].nunique()
            for date, count in by_date.items():
                print(f"    {date}: {count:,} devices")
            
            if threshold == 15:  # Detailed breakdown for 15 min
                by_type = filtered.groupby(['date', 'address_type'])['address'].nunique()
                print(f"    By address type:")
                for (date, addr_type), count in by_type.items():
                    print(f"      {addr_type}: {count:,}")
        
        # 5. Apply rotation factors
        print(f"\n4. ATTENDANCE ESTIMATION (10+ min dwell):")
        
        rotation_factors = {
            'Random Static': 1.0,
            'Public': 1.0,
            'Public (Unregistered OUI)': 1.0,
            'Random Private Resolvable': 4.0,
            'Random Private Non-Resolvable': 8.0,
            'Ambiguous': 2.0,
            'Invalid': 1.0
        }
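        # Private addresses rotate periodically (typically on the order of
        # 15 minutes on modern iOS/Android), so one phone appears as several
        # addresses over the event; dividing by these factors folds them back
        # into an estimated device count. The values are heuristic assumptions,
        # not measured rotation rates.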
        
        ble_10min = ble_device_dwell[ble_device_dwell['total_dwell_time'] >= 10]
        
        for date in ble_10min['date'].unique():
            date_data = ble_10min[ble_10min['date'] == date]
            raw_count = len(date_data)
            
            # Calculate adjusted count
            adjusted_count = 0
            for addr_type, factor in rotation_factors.items():
                type_count = len(date_data[date_data['address_type'] == addr_type])
                adjusted_count += type_count / factor
            
            # Blend in the cluster count as a secondary, independent estimate
            # (the 70/30 weighting is a heuristic)
            if date in self.clusters and len(self.clusters[date]) > 0:
                cluster_estimate = len(self.clusters[date])
                adjusted_count = 0.7 * adjusted_count + 0.3 * cluster_estimate
            
            # Classic BT devices with >= 10 min dwell on this date
            date_classic = classic_bt_df[classic_bt_df['date'] == date]
            classic_dwell = date_classic.groupby('address')['dwellTime'].sum()
            classic_10min = int((classic_dwell >= 10).sum())
            
            print(f"\n  {date}:")
            print(f"    BLE raw count (10+ min): {raw_count:,}")
            print(f"    BLE adjusted estimate: {int(adjusted_count):,}")
            print(f"    Classic BT (10+ min): {classic_10min:,}")
            print(f"    Clusters found: {len(self.clusters.get(date, {}))}")
            print(f"    TOTAL ESTIMATE: {int(adjusted_count) + classic_10min:,}")
        
        # Save results
        self.save_results()
    
    def save_results(self):
        """Save analysis results"""
        # Save device fingerprints
        fingerprint_data = []
        for addr, fp in self.fingerprints.items():
            fp_record = {
                'address': addr,
                'address_type': self.classify_address(addr),
                'is_infrastructure': addr in self.infrastructure_devices,
                'is_classic_bt': addr in self.classic_bt_devices,
                **fp
            }
            fingerprint_data.append(fp_record)
        
        fp_df = pd.DataFrame(fingerprint_data)
        fp_df.to_csv(os.path.join(RESULTS_DIR, 'device_fingerprints.csv'), index=False)
        
        # Save device categories
        categories = []
        for addr in self.infrastructure_devices:
            categories.append({'address': addr, 'category': 'infrastructure'})
        for addr in self.classic_bt_devices:
            categories.append({'address': addr, 'category': 'classic_bluetooth'})
        
        if categories:
            cat_df = pd.DataFrame(categories)
            cat_df.to_csv(os.path.join(RESULTS_DIR, 'device_categories.csv'), index=False)
        
        # Save clusters
        if any(self.clusters.values()):
            cluster_data = []
            for date, clusters in self.clusters.items():
                for cluster_id, addresses in clusters.items():
                    for addr in addresses:
                        cluster_data.append({
                            'date': date,
                            'cluster_id': cluster_id,
                            'address': addr
                        })
            
            if cluster_data:
                cluster_df = pd.DataFrame(cluster_data)
                cluster_df.to_csv(os.path.join(RESULTS_DIR, 'device_clusters.csv'), index=False)
        
        # Save session summary
        self.df.to_csv(os.path.join(RESULTS_DIR, 'device_sessions.csv'), index=False)
        
        print(f"\nResults saved to {RESULTS_DIR}/")

def main():
    analyzer = OptimizedBLEAnalyzer()
    
    # Load and process data
    analyzer.load_and_process_data()
    
    # Perform analysis
    if not analyzer.df.empty:
        analyzer.analyze_attendance()
    else:
        print("No data loaded!")

if __name__ == "__main__":
    main()