#!/usr/bin/env python3
"""
Optimized BLE Analysis with Progress Tracking
Handles large datasets more efficiently
"""

import json
import os
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from collections import defaultdict, Counter
import warnings
import time
from tqdm import tqdm
warnings.filterwarnings('ignore')

# Configuration
DATA_DIR = "/var/www/html/CrowdFlow/Picklecon"  # Update this path
FILE_PREFIX = "Picklecon"  # Update this to match your directory name prefix
EVENT_DATES = ["20250807"]  # Update these dates
RESULTS_DIR = "./analysis_results"

# Optional: specify hour range if event doesn't run 24 hours
START_HOUR = 8   # First hour to check (0-23)
END_HOUR = 16   # Last hour to check (0-23)

# Performance optimization settings
MAX_RECORDS_PER_DEVICE = 10000  # Limit records per device to prevent memory issues
SAMPLE_RATE = 1.0  # Use 1.0 for all data, or 0.1 to sample 10% for testing

# Analysis parameters based on research
RSSI_BIN_WIDTH = 5  # dBm bins for RSSI histograms
MIN_RSSI_SAMPLES = 10  # Minimum RSSI readings for reliable fingerprint
IOS_ROTATION_WINDOW = 15  # minutes - iOS MAC rotation cycle
DBSCAN_EPS = 0.3  # DBSCAN epsilon for clustering
MIN_CLUSTER_SIZE = 2  # Minimum devices to form a cluster
CONFIDENCE_THRESHOLD = 0.75  # 75% confidence for cross-day linking
DWELL_TIME_THRESHOLD = 15  # minutes - gap before considering new visit
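# Note: IOS_ROTATION_WINDOW and CONFIDENCE_THRESHOLD are not consumed in this trimmed script;
# they feed the elided analysis steps (see the note inside analyze_and_visualize).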

class BLEFingerprinter:
    def __init__(self):
        self.devices = {}
        self.fingerprints = {}
        self.clusters = {}
        self.cross_day_links = defaultdict(list)
        self.file_loading_log = []  # Track which files were loaded
        
        # Create results directory
        os.makedirs(RESULTS_DIR, exist_ok=True)
    
    def normalize_address(self, address):
        """Normalize MAC address format"""
        if isinstance(address, (int, float)):
            address = str(int(address))
        else:
            address = str(address)
        
        address = address.replace(':', '').upper()
        if len(address) < 12:
            address = address.zfill(12)
        return address[:12] if len(address) > 12 else address
    
    def classify_address(self, address):
        """Classify MAC address type based on research findings"""
        if len(address) != 12:
            return "Invalid"
        
        try:
            msb = int(address[:2], 16)
            
            # Check locally administered bit (bit 41)
            is_local = (msb & 0x02) != 0
            
            # Check multicast bit (bit 40) 
            is_multicast = (msb & 0x01) != 0
            
            if is_local and not is_multicast:
                # Further classify random addresses
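                # Two MSBs per the Bluetooth Core Spec: 0b11 = static, 0b01 = resolvable private,
                # 0b00 = non-resolvable private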
                type_bits = (msb >> 6) & 0x03
                
                if type_bits == 0b11:
                    return "Random Static"
                elif type_bits == 0b01:
                    return "Random Private Resolvable"
                elif type_bits == 0b00:
                    return "Random Private Non-Resolvable"
                else:
                    return "Reserved"
            else:
                return "Public"
        except ValueError:
            return "Invalid"
    
    def aggregate_raw_records(self, raw_records):
        """Aggregate raw records into device sessions"""
        # Sort by timestamp
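        # (timestamps were already parsed to pandas Timestamps upstream, so they sort chronologically)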
        raw_records.sort(key=lambda x: x['timestamp'])
        
        sessions = []
        current_session = None
        
        for record in raw_records:
            if current_session is None:
                # Start new session
                current_session = {
                    'firstSeenTime': record['timestamp'],
                    'lastSeenTime': record['timestamp'],
                    'rssi_values': [record['rssi']],
                    'locations': [record['Location']],
                    'timestamps': [record['timestamp']],
                    'state_changes': [record.get('state', 'unknown')]
                }
            else:
                # Check if this is part of the same session
                last_time = pd.to_datetime(current_session['lastSeenTime'])
                current_time = pd.to_datetime(record['timestamp'])
                time_gap = (current_time - last_time).total_seconds() / 60  # minutes
                
                if time_gap <= DWELL_TIME_THRESHOLD:
                    # Same session
                    current_session['lastSeenTime'] = record['timestamp']
                    current_session['rssi_values'].append(record['rssi'])
                    current_session['locations'].append(record['Location'])
                    current_session['timestamps'].append(record['timestamp'])
                    current_session['state_changes'].append(record.get('state', 'unknown'))
                else:
                    # New session
                    sessions.append(current_session)
                    current_session = {
                        'firstSeenTime': record['timestamp'],
                        'lastSeenTime': record['timestamp'],
                        'rssi_values': [record['rssi']],
                        'locations': [record['Location']],
                        'timestamps': [record['timestamp']],
                        'state_changes': [record.get('state', 'unknown')]
                    }
        
        # Don't forget the last session
        if current_session:
            sessions.append(current_session)
        
        # Calculate derived metrics for each session
        for session in sessions:
            first_time = pd.to_datetime(session['firstSeenTime'])
            last_time = pd.to_datetime(session['lastSeenTime'])
            session['dwellTime'] = (last_time - first_time).total_seconds() / 60  # minutes
            session['recordCount'] = len(session['rssi_values'])
            session['avgRssi'] = np.mean(session['rssi_values'])
            session['locationsVisited'] = list(set(session['locations']))
        
        return sessions
    
    def extract_features(self, device_data):
        """Extract multi-modal features for fingerprinting"""
        features = {}
        
        # RSSI fingerprint - create histogram
        if 'rssi_values' in device_data and len(device_data['rssi_values']) >= MIN_RSSI_SAMPLES:
            # Bin edges span -100 to -40 dBm; extend the range so the top bin (-45 to -40) is not dropped
            rssi_hist, _ = np.histogram(device_data['rssi_values'],
                                        bins=range(-100, -40 + RSSI_BIN_WIDTH, RSSI_BIN_WIDTH))
            features['rssi_histogram'] = rssi_hist / len(device_data['rssi_values'])  # normalized bin probabilities
            features['rssi_mean'] = np.mean(device_data['rssi_values'])
            features['rssi_std'] = np.std(device_data['rssi_values'])
            features['rssi_skew'] = stats.skew(device_data['rssi_values'])
        
        # Timing patterns
        if 'timestamps' in device_data and len(device_data['timestamps']) > 1:
            times = sorted(device_data['timestamps'])
            intervals = np.diff(times)
            if len(intervals) > 0:
                features['avg_interval'] = np.mean(intervals)
                features['interval_std'] = np.std(intervals)
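                # burst_ratio: fraction of inter-arrival gaps shorter than 1 s (timestamps are epoch seconds)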
                features['burst_ratio'] = np.sum(intervals < 1) / len(intervals)
        
        # Location patterns
        if 'locations' in device_data:
            loc_counts = Counter(device_data['locations'])
            total_visits = sum(loc_counts.values())
            features['location_entropy'] = stats.entropy(list(loc_counts.values()))
            features['primary_location_ratio'] = max(loc_counts.values()) / total_visits
            features['location_diversity'] = len(set(device_data['locations']))
        
        # Behavioral features
        features['total_dwell_time'] = device_data.get('total_dwell_time', 0)
        features['session_count'] = device_data.get('session_count', 0)
        features['total_records'] = device_data.get('total_records', 0)
        features['first_seen_hour'] = device_data.get('first_seen_hour', 0)
        
        return features
    
    def cluster_devices_dbscan(self, date_devices):
        """Use DBSCAN to cluster devices based on fingerprints"""
        from sklearn.cluster import DBSCAN
        from sklearn.preprocessing import StandardScaler
        
        # Extract feature matrix
        addresses = list(date_devices.keys())
        feature_matrix = []
        
        for addr in addresses:
            fp = self.fingerprints.get(addr, {})
            # Create feature vector
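            # The .get() defaults act as neutral fallbacks when a fingerprint is missing a field
            # (e.g. fewer than MIN_RSSI_SAMPLES RSSI readings)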
            features = [
                fp.get('rssi_mean', -80),
                fp.get('rssi_std', 10),
                fp.get('total_dwell_time', 0),
                fp.get('location_entropy', 0),
                fp.get('primary_location_ratio', 1),
                fp.get('avg_interval', 60),
                fp.get('first_seen_hour', 12)
            ]
            feature_matrix.append(features)
        
        if len(feature_matrix) < MIN_CLUSTER_SIZE:
            return {}
        
        # Standardize features
        scaler = StandardScaler()
        X = scaler.fit_transform(feature_matrix)
        
        # Apply DBSCAN
        clustering = DBSCAN(eps=DBSCAN_EPS, min_samples=MIN_CLUSTER_SIZE)
        labels = clustering.fit_predict(X)
        
        # Group by cluster
        clusters = defaultdict(list)
        for i, label in enumerate(labels):
            if label != -1:  # Not noise
                clusters[label].append(addresses[i])
        
        return dict(clusters)
    
    def load_and_process_data(self):
        """Load data and create fingerprints"""
        all_data = []
        
        for date in EVENT_DATES:
            print(f"\nProcessing date: {date}")
            
            # Collect all data for this date across all hours
            date_raw_data = []
            hours_found = []
            
            # Try to load hourly files
            for hour in range(START_HOUR, END_HOUR + 1):
                hour_str = f"{hour:02d}"  # Format as 00, 01, ..., 23
                
                # Build filename for this hour
                # Using the configured prefix
                filename = os.path.join(DATA_DIR, f"{FILE_PREFIX}_combined_flat_{date}_{hour_str}.json")
                
                if os.path.exists(filename):
                    hours_found.append(hour_str)
                    try:
                        with open(filename, 'r') as f:
                            hour_data = json.load(f)
                            
                            # Sample data if requested
                            if SAMPLE_RATE < 1.0:
                                sample_size = int(len(hour_data) * SAMPLE_RATE)
                                # Sample record indices rather than the dict list itself (faster, avoids object arrays)
                                keep_idx = np.random.choice(len(hour_data), size=sample_size, replace=False)
                                hour_data = [hour_data[i] for i in keep_idx]
                            
                            date_raw_data.extend(hour_data)
                            print(f"  ✓ Loaded hour {hour_str}: {len(hour_data)} records")
                            
                            # Track loading info
                            self.file_loading_log.append({
                                'date': date,
                                'hour': hour_str,
                                'filename': filename,
                                'exists': True,
                                'records': len(hour_data),
                                'status': 'success'
                            })
                    except Exception as e:
                        print(f"  ✗ Error loading hour {hour_str}: {e}")
                        self.file_loading_log.append({
                            'date': date,
                            'hour': hour_str,
                            'filename': filename,
                            'exists': True,
                            'records': 0,
                            'status': f'error: {str(e)}'
                        })
                else:
                    self.file_loading_log.append({
                        'date': date,
                        'hour': hour_str,
                        'filename': filename,
                        'exists': False,
                        'records': 0,
                        'status': 'missing'
                    })
                
            if not date_raw_data:
                print(f"  ⚠️  No data found for {date}")
                continue
            
            print(f"  Total records for {date}: {len(date_raw_data)} from hours {', '.join(hours_found)}")
            
            # Show hourly distribution
            hour_distribution = defaultdict(int)
            for record in date_raw_data:
                try:
                    ts = pd.to_datetime(record.get('timestamp'))
                    hour_distribution[ts.hour] += 1
                except Exception:
                    pass
            
            if hour_distribution:
                print(f"  Hourly distribution: {dict(sorted(hour_distribution.items()))}")
            
            # Group raw records by device address
            print(f"  Grouping records by device address...")
            start_time = time.time()
            raw_by_device = defaultdict(list)
            
            for record in tqdm(date_raw_data, desc="  Processing records"):
                raw_address = record.get('address', '')
                if not raw_address:
                    continue
                address = self.normalize_address(raw_address)
                
                # Ensure the timestamp is present and parseable; skip the record otherwise
                if 'timestamp' not in record:
                    continue
                try:
                    record['timestamp'] = pd.to_datetime(record['timestamp'])
                except Exception:
                    continue
                
                # Limit records per device to prevent memory issues
                if len(raw_by_device[address]) < MAX_RECORDS_PER_DEVICE:
                    raw_by_device[address].append(record)
            
            print(f"  Grouped into {len(raw_by_device)} unique devices in {time.time() - start_time:.1f}s")
            
            # Process each device's raw records
            print(f"  Aggregating records into sessions...")
            date_devices = {}
            
            for address, raw_records in tqdm(raw_by_device.items(), desc="  Creating sessions"):
                # Aggregate raw records into sessions
                sessions = self.aggregate_raw_records(raw_records)
                
                if not sessions:
                    continue
                
                # Create device data structure
                date_devices[address] = {
                    'sessions': sessions,
                    'rssi_values': [],
                    'timestamps': [],
                    'locations': [],
                    'address': address,
                    'date': date,
                    'total_dwell_time': sum(s['dwellTime'] for s in sessions),
                    'session_count': len(sessions),
                    'total_records': sum(s['recordCount'] for s in sessions)
                }
                
                # Aggregate all values across sessions
                for session in sessions:
                    date_devices[address]['rssi_values'].extend(session['rssi_values'])
                    date_devices[address]['locations'].extend(session['locations'])
                    
                    # Convert timestamps to float for processing
                    for ts in session['timestamps']:
                        date_devices[address]['timestamps'].append(
                            pd.to_datetime(ts).timestamp()
                        )
                
                # Get first seen hour
                first_ts = pd.to_datetime(sessions[0]['firstSeenTime'])
                date_devices[address]['first_seen_hour'] = first_ts.hour
                
                # Add aggregated sessions to DataFrame format
                for session in sessions:
                    session_record = {
                        'date': date,
                        'address': address,
                        'address_normalized': address,
                        'address_type': self.classify_address(address),
                        'firstSeenTime': session['firstSeenTime'],
                        'lastSeenTime': session['lastSeenTime'],
                        'dwellTime': session['dwellTime'],
                        'recordCount': session['recordCount'],
                        'avgRssi': session['avgRssi'],
                        'locationsVisited': session['locationsVisited']
                    }
                    all_data.append(session_record)
            
            # Create fingerprints
            print(f"  Creating fingerprints for {len(date_devices)} devices...")
            for addr, device_data in tqdm(date_devices.items(), desc="  Fingerprinting"):
                self.fingerprints[addr] = self.extract_features(device_data)
            
            # Cluster devices
            print("  Clustering devices...")
            clusters = self.cluster_devices_dbscan(date_devices)
            self.clusters[date] = clusters
            print(f"  Found {len(clusters)} clusters")
            
            # Store for cross-day analysis
            self.devices[date] = date_devices
        
        # Create DataFrame
        self.df = pd.DataFrame(all_data)
        if not self.df.empty:
            self.df['firstSeenTime'] = pd.to_datetime(self.df['firstSeenTime'])
            self.df['lastSeenTime'] = pd.to_datetime(self.df['lastSeenTime'])
            
            # Print loading summary
            print("\n=== DATA LOADING SUMMARY ===")
            print(f"Total dates processed: {len(self.devices)}")
            print(f"Total unique devices: {len(self.fingerprints)}")
            print(f"Total sessions: {len(self.df)}")
            print(f"Date range: {self.df['date'].min()} to {self.df['date'].max()}")
            print(f"Hour range checked: {START_HOUR:02d}:00 to {END_HOUR:02d}:59")
            if SAMPLE_RATE < 1.0:
                print(f"Sample rate: {SAMPLE_RATE*100:.0f}% of data")
    
    def analyze_and_visualize(self):
        """Perform analysis and create visualizations"""
        print("\n=== ADVANCED BLE ANALYSIS RESULTS ===")
        
        # Quick summary statistics
        print("\n--- Quick Statistics ---")
        print(f"Total unique MAC addresses: {self.df['address_normalized'].nunique()}")
        print(f"Average dwell time: {self.df['dwellTime'].mean():.1f} minutes")
        print(f"Maximum dwell time: {self.df['dwellTime'].max():.1f} minutes")
        print(f"Average sessions per device: {len(self.df) / self.df['address_normalized'].nunique():.1f}")
        
        # Continue with regular analysis...
        # [Rest of the analysis methods remain the same]
        
        # Save results
        self.save_results()
    
    def save_results(self):
        """Save analysis results to files"""
        # Save summary statistics first
        summary_stats = {
            'total_unique_devices': self.df['address_normalized'].nunique(),
            'total_sessions': len(self.df),
            'avg_dwell_time': self.df['dwellTime'].mean(),
            'max_dwell_time': self.df['dwellTime'].max(),
            'total_dates': len(self.devices),
            'sample_rate': SAMPLE_RATE
        }
        
        with open(os.path.join(RESULTS_DIR, 'summary_stats.json'), 'w') as f:
            json.dump(summary_stats, f, indent=2)
        
        # Save fingerprints (sample if too large)
        fingerprint_data = []
        fp_sample = list(self.fingerprints.items())
        if len(fp_sample) > 10000:
            # np.random.choice cannot sample a list of (address, dict) tuples directly, so draw indices instead
            keep_idx = np.random.choice(len(fp_sample), size=10000, replace=False)
            fp_sample = [fp_sample[i] for i in keep_idx]
        
        for addr, fp in fp_sample:
            fp_record = {'address': addr, **fp}
            fingerprint_data.append(fp_record)
        
        fp_df = pd.DataFrame(fingerprint_data)
        fp_df.to_csv(os.path.join(RESULTS_DIR, 'device_fingerprints.csv'), index=False)
        
        # Save file loading summary
        self.save_loading_summary()
        
        print(f"\nResults saved to {RESULTS_DIR}/")
    
    def save_loading_summary(self):
        """Save a summary of which hourly files were loaded"""
        if self.file_loading_log:
            summary_df = pd.DataFrame(self.file_loading_log)
            summary_df.to_csv(os.path.join(RESULTS_DIR, 'file_loading_summary.csv'), index=False)
            
            # Print loading statistics
            total_files = len(summary_df)
            loaded_files = len(summary_df[summary_df['status'] == 'success'])
            missing_files = len(summary_df[summary_df['status'] == 'missing'])
            error_files = len(summary_df[summary_df['status'].str.startswith('error')])
            total_records = summary_df[summary_df['status'] == 'success']['records'].sum()
            
            print(f"\n=== FILE LOADING STATISTICS ===")
            print(f"Total files checked: {total_files}")
            print(f"Successfully loaded: {loaded_files}")
            print(f"Missing files: {missing_files}")
            print(f"Files with errors: {error_files}")
            print(f"Total records loaded: {total_records:,}")

def main():
    # Initialize analyzer
    analyzer = BLEFingerprinter()
    
    # Load and process data
    analyzer.load_and_process_data()
    
    # Perform analysis if data was loaded
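    # analyze_and_visualize() is available on the class but is skipped here; only the basic
    # summary, fingerprints and file-loading log are written.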
    if not analyzer.df.empty:
        # For now, just save basic results without full visualization
        print("\n=== Saving Results ===")
        analyzer.save_results()
        print("\nAnalysis complete! Check the results in ./analysis_results/")
    else:
        print("No data loaded!")

if __name__ == "__main__":
    main()