#!/usr/bin/env python3
"""
Test script to verify attendance estimation is working
Run this after the main analysis to check filtering
"""

import pandas as pd
import numpy as np
import os

# Load the saved session data
RESULTS_DIR = "./analysis_results"
session_file = os.path.join(RESULTS_DIR, "device_sessions.csv")

if not os.path.exists(session_file):
    print(f"Error: {session_file} not found. Run the main analysis first.")
    # The bare `exit` helper is injected by the `site` module for interactive
    # sessions and is absent under `python -S` or in frozen builds; raising
    # SystemExit gives the same exit status without that dependency.
    raise SystemExit(1)

# Load data
print("Loading session data...")
# Read addresses as text: with dtype inference, addresses made only of digits
# (or shaped like "123e45678901") can be parsed as numbers, which breaks the
# string-based classification done later in this script.
df = pd.read_csv(session_file, dtype={'address': str})
# Timestamps arrive as text in the CSV; convert them to datetime64 columns.
df['firstSeenTime'] = pd.to_datetime(df['firstSeenTime'])
df['lastSeenTime'] = pd.to_datetime(df['lastSeenTime'])

print("\n=== ATTENDANCE ESTIMATION ANALYSIS ===")

# 1. Raw counts: number of distinct addresses seen on each date
print("\n1. RAW UNIQUE DEVICES:")
daily_unique = df.groupby('date')['address'].nunique()
for day in daily_unique.index:
    print(f"  {day}: {daily_unique.loc[day]:,}")

# 2. Address type breakdown
print("\n2. ADDRESS TYPE BREAKDOWN:")
# The classifier must be defined before it can be applied to the data
def classify_address(address):
    """Classify a MAC address from the flag bits of its first octet.

    Args:
        address: The address as 12 hexadecimal characters with no
            separators (upper or lower case).

    Returns:
        One of the following labels:

        * ``"Invalid"`` - ``address`` is not a string, is not exactly 12
          characters long, or contains a non-hexadecimal character.
        * ``"Public"`` - the locally-administered bit (0x02) is clear or
          the multicast bit (0x01) is set.
        * ``"Random Static"``, ``"Random Private Resolvable"``,
          ``"Random Private Non-Resolvable"`` or ``"Reserved"`` - the
          address is local and unicast; the label is chosen by the two
          most significant bits of the first octet (0b11, 0b01, 0b00 and
          0b10 respectively).

    NOTE(review): the labels look like Bluetooth LE random-address
    sub-types; confirm that gating on the locally-administered bit is the
    intended heuristic for the capture source.
    """
    # Missing values reach this function as float NaN / None when applied to
    # a pandas column; treat anything that is not a string as invalid rather
    # than letting len() raise and abort the whole classification pass.
    if not isinstance(address, str) or len(address) != 12:
        return "Invalid"

    # Validate every character explicitly: int(x, 16) would also accept
    # signs, surrounding whitespace and non-ASCII digits.
    if any(ch not in "0123456789abcdefABCDEF" for ch in address):
        return "Invalid"

    msb = int(address[:2], 16)
    is_local = (msb & 0x02) != 0
    is_multicast = (msb & 0x01) != 0

    if not is_local or is_multicast:
        return "Public"

    # The two most significant bits of the first octet select the sub-type.
    type_bits = (msb >> 6) & 0x03
    if type_bits == 0b11:
        return "Random Static"
    if type_bits == 0b01:
        return "Random Private Resolvable"
    if type_bits == 0b00:
        return "Random Private Non-Resolvable"
    return "Reserved"

# Tag every session row with the class of its address
df['address_type'] = df['address'].apply(classify_address)

# Per date (in order of first appearance), list distinct addresses per type,
# most common type first
for day in df['date'].unique():
    on_day = df.loc[df['date'] == day]
    print(f"\n  {day}:")
    per_type = (
        on_day.groupby('address_type')['address']
        .nunique()
        .sort_values(ascending=False)
    )
    for label in per_type.index:
        print(f"    {label}: {per_type.loc[label]:,}")

# 3. Dwell time filtering
print("\n3. FILTERED BY DWELL TIME:")

# Calculate total dwell time per device: one row per (date, address,
# address_type) holding the sum of that device's session dwell times.
device_dwell = (
    df.groupby(['date', 'address', 'address_type'])['dwellTime']
    .sum()
    .reset_index()
    # Rename by name rather than by position so a change in the grouping
    # keys cannot silently mislabel columns.
    .rename(columns={'dwellTime': 'total_dwell_time'})
)

# Different thresholds, compared directly against the summed dwellTime.
# NOTE(review): the printed labels assume dwellTime is in minutes - confirm
# against the main analysis that writes the CSV.
thresholds = [0, 5, 10, 15, 30, 60]

for threshold in thresholds:
    print(f"\n  Devices with ≥{threshold} min total dwell time:")
    filtered = device_dwell[device_dwell['total_dwell_time'] >= threshold]

    by_date = filtered.groupby('date')['address'].nunique()
    for date, count in by_date.items():
        print(f"    {date}: {count:,} total")

    # Only the 10-minute threshold gets a per-type breakdown.
    if threshold == 10:
        print("    Breakdown by type:")
        by_type = filtered.groupby(['date', 'address_type'])['address'].nunique()
        for (date, addr_type), count in by_type.items():
            # Include the date: with several dates in the data the same type
            # appears once per date and is otherwise indistinguishable.
            print(f"      {date} - {addr_type}: {count:,}")

# 4. Random Static with 10+ minutes
print("\n4. RANDOM STATIC DEVICES WITH ≥10 MIN DWELL TIME:")
is_random_static = device_dwell['address_type'].eq('Random Static')
stayed_10_min = device_dwell['total_dwell_time'].ge(10)
random_static_10min = device_dwell.loc[is_random_static & stayed_10_min]

if random_static_10min.empty:
    print("  No Random Static devices found with ≥10 min dwell time")
else:
    # Groups come back sorted by date; each group holds that day's devices.
    for day, day_rows in random_static_10min.groupby('date'):
        print(f"  {day}: {day_rows['address'].nunique():,} devices")

        dwell = day_rows['total_dwell_time']
        print(f"    Dwell stats: mean={dwell.mean():.1f} min, "
              f"median={dwell.median():.1f} min, max={dwell.max():.1f} min")

# 5. Summary
print("\n=== SUMMARY ===")

# These row filters do not depend on the date, so build them once.
dwell_10_plus = device_dwell['total_dwell_time'] >= 10
random_static = device_dwell['address_type'] == 'Random Static'

for day, all_devices in daily_unique.items():
    print(f"\n{day}:")
    print(f"  All devices: {all_devices:,}")

    on_day = device_dwell['date'] == day

    # Devices with 10+ total dwell time on this date
    n_10min = device_dwell.loc[on_day & dwell_10_plus, 'address'].nunique()
    print(f"  10+ min dwell: {n_10min:,}")

    # Same filter, restricted to Random Static addresses
    n_rs_10min = device_dwell.loc[
        on_day & random_static & dwell_10_plus, 'address'
    ].nunique()
    print(f"  Random Static 10+ min: {n_rs_10min:,}")

    # Share of all of the day's devices; skipped when the day has none
    if all_devices > 0:
        print(f"  Random Static 10+ min %: {n_rs_10min/all_devices*100:.1f}%")