#!/usr/bin/env python3
"""
Advanced BLE Analysis with Fingerprinting and Clustering
Implements research-based techniques to handle MAC randomization
"""

import json
import os
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from scipy import stats
from collections import defaultdict, Counter
import warnings
warnings.filterwarnings('ignore')

# Configuration
DATA_DIR = "/var/www/html/CrowdFlow/Picklecon"
EVENT_DATES = ["20250807", "20250808", "20250809", "20250810"]
RESULTS_DIR = "./analysis_results"

# Analysis parameters based on research
RSSI_BIN_WIDTH = 5  # dBm bins for RSSI histograms
MIN_RSSI_SAMPLES = 10  # Minimum RSSI readings for reliable fingerprint
IOS_ROTATION_WINDOW = 15  # minutes - iOS MAC rotation cycle
DBSCAN_EPS = 0.3  # DBSCAN epsilon for clustering
MIN_CLUSTER_SIZE = 2  # Minimum devices to form a cluster
CONFIDENCE_THRESHOLD = 0.75  # 75% confidence for cross-day linking
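# Note: these thresholds (bin width, DBSCAN eps, confidence cut-off) are heuristic starting
# points and should be re-tuned against ground truth for each venue and scanner layout.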

class BLEFingerprinter:
    def __init__(self):
        self.devices = {}
        self.fingerprints = {}
        self.clusters = {}
        self.cross_day_links = defaultdict(list)
        
        # Create results directory
        os.makedirs(RESULTS_DIR, exist_ok=True)
    
    def normalize_address(self, address):
        """Normalize a MAC address to 12 uppercase hex characters with no separators"""
        if isinstance(address, (int, float)):
            address = str(int(address))
        else:
            address = str(address)
        
        address = address.replace(':', '').upper()
        if not address:
            return ''  # Preserve emptiness so callers can skip records with no address
        if len(address) < 12:
            address = address.zfill(12)
        return address[:12]
    
    def classify_address(self, address):
        """Classify MAC address type based on research findings"""
        if len(address) != 12:
            return "Invalid"
        
        try:
            msb = int(address[:2], 16)
            
            # Locally administered (U/L) bit: second-least-significant bit of the first octet
            is_local = (msb & 0x02) != 0
            
            # Multicast / group (I/G) bit: least-significant bit of the first octet
            is_multicast = (msb & 0x01) != 0
            
            if is_local and not is_multicast:
                # Further classify random addresses: the two most significant bits of the
                # address encode the random sub-type
                type_bits = (msb >> 6) & 0x03
                
                if type_bits == 0b11:
                    return "Random Static"
                elif type_bits == 0b01:
                    return "Random Private Resolvable"
                elif type_bits == 0b00:
                    return "Random Private Non-Resolvable"
                else:
                    return "Reserved"
            else:
                return "Public"
        except ValueError:
            return "Invalid"
    
    def extract_features(self, device_data):
        """Extract multi-modal features for fingerprinting"""
        features = {}
        
        # RSSI fingerprint - create histogram
        if 'rssi_values' in device_data and len(device_data['rssi_values']) >= MIN_RSSI_SAMPLES:
            # Histogram over -100..-40 dBm; readings outside this range are not counted
            rssi_hist, _ = np.histogram(device_data['rssi_values'],
                                        bins=range(-100, -40 + RSSI_BIN_WIDTH, RSSI_BIN_WIDTH))
            features['rssi_histogram'] = rssi_hist / len(device_data['rssi_values'])
            features['rssi_mean'] = np.mean(device_data['rssi_values'])
            features['rssi_std'] = np.std(device_data['rssi_values'])
            features['rssi_skew'] = stats.skew(device_data['rssi_values'])
        
        # Timing patterns
        if 'timestamps' in device_data and len(device_data['timestamps']) > 1:
            times = sorted(device_data['timestamps'])
            intervals = np.diff(times)
            if len(intervals) > 0:
                features['avg_interval'] = np.mean(intervals)
                features['interval_std'] = np.std(intervals)
                features['burst_ratio'] = np.sum(intervals < 1) / len(intervals)
        
        # Location patterns
        if 'locations' in device_data:
            loc_counts = Counter(device_data['locations'])
            total_visits = sum(loc_counts.values())
            features['location_entropy'] = stats.entropy(list(loc_counts.values()))
            features['primary_location_ratio'] = max(loc_counts.values()) / total_visits
            features['location_diversity'] = len(set(device_data['locations']))
        
        # Behavioral features
        features['dwell_time'] = device_data.get('dwellTime', 0)
        features['record_count'] = device_data.get('recordCount', 0)
        features['first_seen_hour'] = device_data.get('first_seen_hour', 0)
        
        return features
    
    def compute_fingerprint_similarity(self, fp1, fp2):
        """Compute similarity between two fingerprints"""
        similarity_scores = []
        
        # RSSI histogram similarity (if both have it)
        if 'rssi_histogram' in fp1 and 'rssi_histogram' in fp2:
            norm_product = np.linalg.norm(fp1['rssi_histogram']) * np.linalg.norm(fp2['rssi_histogram'])
            if norm_product > 0:
                # Cosine similarity, counted twice so the RSSI fingerprint is weighted heavily
                # while the overall score stays in [0, 1]
                hist_sim = np.dot(fp1['rssi_histogram'], fp2['rssi_histogram']) / norm_product
                similarity_scores.extend([hist_sim, hist_sim])
        
        # Behavioral similarity
        behavioral_features = ['dwell_time', 'location_entropy', 'primary_location_ratio']
        for feat in behavioral_features:
            if feat in fp1 and feat in fp2:
                # Normalized difference
                max_val = max(abs(fp1[feat]), abs(fp2[feat]), 1)
                sim = 1 - abs(fp1[feat] - fp2[feat]) / max_val
                similarity_scores.append(sim)
        
        # Timing similarity
        if 'avg_interval' in fp1 and 'avg_interval' in fp2:
            interval_sim = 1 - abs(fp1['avg_interval'] - fp2['avg_interval']) / max(
                fp1['avg_interval'], fp2['avg_interval'], 1
            )
            similarity_scores.append(interval_sim)
        
        return np.mean(similarity_scores) if similarity_scores else 0
    
    def cluster_devices_dbscan(self, date_devices):
        """Use DBSCAN to cluster devices based on fingerprints"""
        from sklearn.cluster import DBSCAN
        from sklearn.preprocessing import StandardScaler
        
        # Extract feature matrix
        addresses = list(date_devices.keys())
        feature_matrix = []
        
        for addr in addresses:
            fp = self.fingerprints.get(addr, {})
            # Create feature vector
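            # Fixed-order numeric vector; the fallback defaults below stand in for devices
            # that lack enough samples to compute a given feature (assumed typical values)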
            features = [
                fp.get('rssi_mean', -80),
                fp.get('rssi_std', 10),
                fp.get('dwell_time', 0),
                fp.get('location_entropy', 0),
                fp.get('primary_location_ratio', 1),
                fp.get('avg_interval', 60),
                fp.get('first_seen_hour', 12)
            ]
            feature_matrix.append(features)
        
        if len(feature_matrix) < MIN_CLUSTER_SIZE:
            return {}
        
        # Standardize features
        scaler = StandardScaler()
        X = scaler.fit_transform(feature_matrix)
        
        # Apply DBSCAN
        clustering = DBSCAN(eps=DBSCAN_EPS, min_samples=MIN_CLUSTER_SIZE)
        labels = clustering.fit_predict(X)
        
        # Group by cluster
        clusters = defaultdict(list)
        for i, label in enumerate(labels):
            if label != -1:  # Not noise
                clusters[label].append(addresses[i])
        
        return dict(clusters)
    
    def analyze_ios_rotation_windows(self, df):
        """Analyze 15-minute windows for iOS rotation patterns"""
        # Create 15-minute bins; work on a copy so we do not mutate the caller's slice
        df = df.copy()
        df['time_bin'] = pd.to_datetime(df['firstSeenTime']).dt.floor(f'{IOS_ROTATION_WINDOW}min')
        
        rotation_analysis = []
        
        for time_bin, group in df.groupby('time_bin'):
            addresses = group['address'].unique()
            
            # Look for similar RSSI patterns in this window
            rssi_patterns = {}
            for _, row in group.iterrows():
                addr = row['address']
                if addr not in rssi_patterns:
                    rssi_patterns[addr] = []
                rssi_patterns[addr].append(row.get('avgRssi', row.get('rssi', -80)))
            
            # Find potential rotations (similar RSSI in same window)
            potential_rotations = []
            addrs = list(rssi_patterns.keys())
            
            for i in range(len(addrs)):
                for j in range(i+1, len(addrs)):
                    if len(rssi_patterns[addrs[i]]) > 0 and len(rssi_patterns[addrs[j]]) > 0:
                        rssi_diff = abs(np.mean(rssi_patterns[addrs[i]]) - 
                                      np.mean(rssi_patterns[addrs[j]]))
                        if rssi_diff < 5:  # Within 5 dBm
                            potential_rotations.append((addrs[i], addrs[j], rssi_diff))
            
            rotation_analysis.append({
                'time_bin': time_bin,
                'unique_addresses': len(addresses),
                'potential_rotations': len(potential_rotations)
            })
        
        return pd.DataFrame(rotation_analysis)
    
    def probabilistic_cross_day_linking(self, day1_data, day2_data):
        """Link devices across days using probabilistic matching"""
        links = []
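        # Greedy best-match: each day-1 device links to its single most similar day-2 device
        # above CONFIDENCE_THRESHOLD. A day-2 address may appear in several links, so these
        # are candidate matches rather than one-to-one assignments.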
        
        for addr1, fp1 in day1_data.items():
            best_match = None
            best_score = 0
            
            for addr2, fp2 in day2_data.items():
                # Skip if same address (unlikely with randomization)
                if addr1 == addr2:
                    continue
                
                # Compute similarity
                similarity = self.compute_fingerprint_similarity(fp1, fp2)
                
                if similarity > best_score and similarity >= CONFIDENCE_THRESHOLD:
                    best_score = similarity
                    best_match = addr2
            
            if best_match:
                links.append({
                    'day1_address': addr1,
                    'day2_address': best_match,
                    'confidence': best_score,
                    'day1_type': self.classify_address(addr1),
                    'day2_type': self.classify_address(best_match)
                })
        
        return links
    
    def load_and_process_data(self):
        """Load data and create fingerprints"""
        all_data = []
        
        for date in EVENT_DATES:
            filename = os.path.join(DATA_DIR, f"Picklecon_combineddwell_flat_{date}.json")
            
            if not os.path.exists(filename):
                print(f"⚠️  Missing: {filename}")
                continue
            
            print(f"\nProcessing {date}...")
            
            with open(filename, 'r') as f:
                data = json.load(f)
            
            # Process each device
            date_devices = {}
            
            for record in data:
                address = self.normalize_address(record.get('address', ''))
                if not address:
                    continue
                
                # Initialize device data structure
                if address not in date_devices:
                    date_devices[address] = {
                        'records': [],
                        'rssi_values': [],
                        'timestamps': [],
                        'locations': [],
                        'address': address,
                        'date': date
                    }
                
                # Aggregate data
                date_devices[address]['records'].append(record)
                date_devices[address]['rssi_values'].append(
                    record.get('avgRssi', record.get('rssi', -80))
                )
                
                # Parse timestamp
                try:
                    ts = pd.to_datetime(record.get('firstSeenTime'))
                    date_devices[address]['timestamps'].append(ts.timestamp())
                    date_devices[address]['first_seen_hour'] = ts.hour
                except (ValueError, TypeError):
                    pass
                
                # Locations
                locs = record.get('locationsVisited', [record.get('Location', 'Unknown')])
                if not isinstance(locs, list):
                    locs = [locs]  # Guard against a single location stored as a plain string
                date_devices[address]['locations'].extend(locs)
                
                # Copy key fields
                for field in ['dwellTime', 'recordCount']:
                    if field in record:
                        date_devices[address][field] = record[field]
                
                # Add to DataFrame format
                record['date'] = date
                record['address_normalized'] = address
                record['address_type'] = self.classify_address(address)
                all_data.append(record)
            
            # Create fingerprints
            print(f"Creating fingerprints for {len(date_devices)} devices...")
            for addr, device_data in date_devices.items():
                self.fingerprints[addr] = self.extract_features(device_data)
            
            # Cluster devices
            print("Clustering devices...")
            clusters = self.cluster_devices_dbscan(date_devices)
            self.clusters[date] = clusters
            print(f"Found {len(clusters)} clusters")
            
            # Store for cross-day analysis
            self.devices[date] = date_devices
        
        # Create DataFrame
        self.df = pd.DataFrame(all_data)
        if not self.df.empty:
            for col in ('firstSeenTime', 'lastSeenTime'):
                if col in self.df.columns:
                    self.df[col] = pd.to_datetime(self.df[col], errors='coerce')
    
    def analyze_and_visualize(self):
        """Perform analysis and create visualizations"""
        print("\n=== ADVANCED BLE ANALYSIS RESULTS ===")
        
        # 1. Address type distribution
        plt.figure(figsize=(12, 6))
        
        plt.subplot(1, 2, 1)
        type_counts = self.df['address_type'].value_counts()
        type_counts.plot(kind='bar')
        plt.title('Address Type Distribution')
        plt.xticks(rotation=45)
        
        # 2. RSSI fingerprint visualization
        plt.subplot(1, 2, 2)
        rssi_data = []
        for fp in self.fingerprints.values():
            if 'rssi_mean' in fp:
                rssi_data.append([fp['rssi_mean'], fp.get('rssi_std', 0)])
        
        if rssi_data:
            rssi_df = pd.DataFrame(rssi_data, columns=['RSSI Mean', 'RSSI Std'])
            plt.scatter(rssi_df['RSSI Mean'], rssi_df['RSSI Std'], alpha=0.5)
            plt.xlabel('RSSI Mean (dBm)')
            plt.ylabel('RSSI Std Dev')
            plt.title('RSSI Fingerprint Distribution')
        
        plt.tight_layout()
        plt.savefig(os.path.join(RESULTS_DIR, 'address_analysis.png'))
        plt.close()
        
        # 3. iOS rotation window analysis
        print("\n--- iOS Rotation Window Analysis ---")
        for date in EVENT_DATES:
            if date in self.devices:
                date_df = self.df[self.df['date'] == date]
                rotation_df = self.analyze_ios_rotation_windows(date_df)
                
                avg_addresses = rotation_df['unique_addresses'].mean()
                avg_rotations = rotation_df['potential_rotations'].mean()
                
                print(f"{date}: Avg {avg_addresses:.1f} addresses per 15min, "
                      f"{avg_rotations:.1f} potential rotations")
        
        # 4. Cross-day linking
        print("\n--- Cross-Day Probabilistic Linking ---")
        dates = sorted(self.devices.keys())
        
        cross_day_summary = []
        for i in range(len(dates)-1):
            date1, date2 = dates[i], dates[i+1]
            
            # Get fingerprints for each day
            day1_fps = {addr: self.fingerprints[addr] 
                       for addr in self.devices[date1].keys() 
                       if addr in self.fingerprints}
            day2_fps = {addr: self.fingerprints[addr] 
                       for addr in self.devices[date2].keys() 
                       if addr in self.fingerprints}
            
            # Perform linking
            links = self.probabilistic_cross_day_linking(day1_fps, day2_fps)
            
            if links:
                link_df = pd.DataFrame(links)
                high_conf_links = link_df[link_df['confidence'] >= 0.8]
                
                print(f"{date1} → {date2}: {len(links)} total links, "
                      f"{len(high_conf_links)} high confidence (≥80%)")
                
                cross_day_summary.append({
                    'day_pair': f"{date1}-{date2}",
                    'total_links': len(links),
                    'high_conf_links': len(high_conf_links),
                    'avg_confidence': link_df['confidence'].mean()
                })
        
        # 5. Clustering effectiveness
        print("\n--- Clustering Analysis ---")
        cluster_summary = []
        
        for date, clusters in self.clusters.items():
            if clusters:
                cluster_sizes = [len(addrs) for addrs in clusters.values()]
                cluster_summary.append({
                    'date': date,
                    'num_clusters': len(clusters),
                    'avg_cluster_size': np.mean(cluster_sizes),
                    'max_cluster_size': max(cluster_sizes),
                    'clustered_devices': sum(cluster_sizes)
                })
        
        if cluster_summary:
            cluster_df = pd.DataFrame(cluster_summary)
            print(cluster_df)
        
        # 6. Attendance estimation
        print("\n--- Attendance Estimation ---")
        
        # Method 1: Raw unique count (overestimate)
        daily_unique = self.df.groupby('date')['address_normalized'].nunique()
        print(f"\nRaw unique devices per day:")
        for date, count in daily_unique.items():
            print(f"  {date}: {count:,}")
        
        # Method 2: Cluster-based estimate
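        # Each DBSCAN cluster is treated as one physical device whose rotated addresses were
        # grouped together; unclustered addresses still count individually. This can undercount
        # if DBSCAN merges genuinely distinct devices with similar behaviour.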
        cluster_based_estimate = {}
        for date, clusters in self.clusters.items():
            # Count clusters + unclustered devices
            clustered = sum(len(addrs) for addrs in clusters.values())
            total = len(self.devices.get(date, {}))
            unclustered = total - clustered
            estimated = len(clusters) + unclustered
            cluster_based_estimate[date] = estimated
            print(f"\nCluster-based estimate for {date}: {estimated:,}")
        
        # Method 3: Cross-day adjusted
        if cross_day_summary:
            avg_high_conf_links = np.mean([s['high_conf_links'] for s in cross_day_summary])
            print(f"\nAvg high-confidence cross-day links per day pair: {avg_high_conf_links:.1f}")
            # Fixed assumption: roughly 30% of daily unique devices are repeat attendees
            total_adjusted = int(sum(daily_unique.values()) * 0.7)
            print(f"Cross-day adjusted total estimate: {total_adjusted:,}")
        
        # 7. Behavioral patterns
        plt.figure(figsize=(15, 10))
        
        # Dwell time distribution by address type
        plt.subplot(2, 2, 1)
        for addr_type in self.df['address_type'].unique():
            type_data = self.df[self.df['address_type'] == addr_type]['dwellTime'].dropna()
            if len(type_data) > 10:
                plt.hist(type_data, bins=30, alpha=0.5, label=addr_type)
        plt.xlabel('Dwell Time (minutes)')
        plt.ylabel('Count')
        plt.title('Dwell Time Distribution by Address Type')
        plt.legend()
        plt.xlim(0, 300)
        
        # Location patterns
        plt.subplot(2, 2, 2)
        location_counts = Counter()
        if 'locationsVisited' in self.df.columns:
            for locs in self.df['locationsVisited']:
                if isinstance(locs, list):
                    location_counts.update(locs)
        
        if location_counts:
            locs, counts = zip(*location_counts.most_common(10))
            plt.bar(locs, counts)
            plt.xlabel('Location')
            plt.ylabel('Visit Count')
            plt.title('Top 10 Visited Locations')
            plt.xticks(rotation=45)
        
        # RSSI patterns over time
        plt.subplot(2, 2, 3)
        rssi_col = 'rssi' if 'rssi' in self.df.columns else 'avgRssi'
        hourly_rssi = self.df.groupby(self.df['firstSeenTime'].dt.hour)[rssi_col].mean()
        hourly_rssi.plot(marker='o')
        plt.xlabel('Hour of Day')
        plt.ylabel('Average RSSI (dBm)')
        plt.title('Average RSSI by Hour')
        plt.grid(True)
        
        # Device count by hour
        plt.subplot(2, 2, 4)
        hourly_counts = self.df.groupby(self.df['firstSeenTime'].dt.hour)['address_normalized'].nunique()
        hourly_counts.plot(kind='bar')
        plt.xlabel('Hour of Day')
        plt.ylabel('Unique Devices')
        plt.title('Unique Devices by Hour')
        
        plt.tight_layout()
        plt.savefig(os.path.join(RESULTS_DIR, 'behavioral_analysis.png'))
        plt.close()
        
        # Save detailed results
        self.save_results()
    
    def save_results(self):
        """Save analysis results to files"""
        # Save fingerprints
        fingerprint_data = []
        for addr, fp in self.fingerprints.items():
            fp_record = {'address': addr, **fp}
            fingerprint_data.append(fp_record)
        
        fp_df = pd.DataFrame(fingerprint_data)
        fp_df.to_csv(os.path.join(RESULTS_DIR, 'device_fingerprints.csv'), index=False)
        
        # Save clusters
        cluster_data = []
        for date, clusters in self.clusters.items():
            for cluster_id, addresses in clusters.items():
                for addr in addresses:
                    cluster_data.append({
                        'date': date,
                        'cluster_id': cluster_id,
                        'address': addr
                    })
        
        if cluster_data:
            cluster_df = pd.DataFrame(cluster_data)
            cluster_df.to_csv(os.path.join(RESULTS_DIR, 'device_clusters.csv'), index=False)
        
        print(f"\nResults saved to {RESULTS_DIR}/")

def main():
    # Initialize analyzer
    analyzer = BLEFingerprinter()
    
    # Load and process data
    analyzer.load_and_process_data()
    
    # Perform analysis
    if not analyzer.df.empty:
        analyzer.analyze_and_visualize()
    else:
        print("No data loaded!")

if __name__ == "__main__":
    main()