import display, math, uinterface, virtualtimers, wifi
from umqtt.simple import MQTTClient

# Display modes; update_display() advances through them in numeric order,
# wrapping at DISPLAY_IDX_COUNT.
DISPLAY_RATE_DOWN = 0
DISPLAY_GRAPH_DOWN = 1
DISPLAY_RATE_UP = 2
DISPLAY_GRAPH_UP = 3
DISPLAY_IDX_COUNT = 4  # number of modes listed above

# Rate thresholds every 5 * 1024**2 (0 .. 75 * 1024**2), each mapped to a colour
# that fades from pure green (0x00FF00) to pure red (0xFF0000) in 16 equal
# steps: red rises and green falls by 0x11 per step.
COLOR_MAP = {
    step * 5 * 1024 ** 2: ((step * 0x11) << 16) | ((0xFF - step * 0x11) << 8)
    for step in range(16)
}

# Unit suffixes for successive powers of 1024.
# NOTE(review): callback() stores octets * 8, i.e. bits, yet these suffixes
# read as bytes ('B', 'KB', ...). Confirm which unit is intended on screen.
HUMAN_RATES = 'B KB MB GB TB PB'.split()


# Note: We're not using a class as I've had bad experiences with the GC on the Badge.team firmware
def callback(topic, msg):
    """Handle one MQTT message with interface octet counters.

    The payload is split on ':' and fields 1 and 2 are used as the "down" and
    "up" values (this looks like collectd's ``<time>:<rx>:<tx>`` if_octets
    format. TODO confirm against the publisher). Trailing NUL bytes are
    stripped first. Each value is multiplied by 8 (octets -> bits), appended
    to ``measurements['down']`` / ``measurements['up']``, and each history is
    capped at the newest 32 samples. A redraw is then requested via
    ``display_dirty``.

    Payloads with missing fields or non-numeric values (e.g. collectd's ``U``
    for "unknown") are logged and ignored instead of raising. Raising here
    would propagate out of ``client.check_msg()`` into the timer task.

    Args:
        topic: Topic as bytes (UTF-8).
        msg: Payload as bytes (UTF-8).
    """
    global display_dirty, measurements
    history = 32  # one sample per column of the 32-pixel-wide graph

    topic = topic.decode('utf-8')
    fields = msg.decode('utf-8').rstrip('\x00').split(':')
    print(">> Callback")
    print(" - Topic: %s" % topic)
    print(" - Message: %s" % fields)

    try:
        down = int(float(fields[1]) * 8)
        up = int(float(fields[2]) * 8)
    except (IndexError, ValueError, OverflowError) as err:
        # IndexError: too few fields; ValueError: 'U', text or NaN;
        # OverflowError: infinity.
        print(" ! Ignoring malformed message: %s" % err)
        return

    measurements['down'].append(down)
    measurements['down'] = measurements['down'][-history:]
    measurements['up'].append(up)
    measurements['up'] = measurements['up'][-history:]
    display_dirty = True
    

def process():
    """Poll the MQTT client once for pending messages.

    Returns:
        100 after polling, or 0 (without polling) when ``wifi.status()`` is
        falsy. The value is returned to virtualtimers, which was given this
        function in main(). Presumably it is the next interval; TODO confirm
        how virtualtimers treats 0.
    """
    global client

    if wifi.status():
        # Dispatches any queued message to callback().
        client.check_msg()
        return 100

    print("Wifi disconnected?")
    return 0


def _graph_row(value, max_val):
    """Return the pixel row (0 = top) at which *value* is plotted.

    Assumes a matrix 8 pixels tall, as the /8 scaling implies. The largest
    sample (``max_val``) lands on row 0. Rows are clamped to 7, so a zero
    sample stays visible instead of being drawn on row 8. An all-zero history
    (``max_val == 0``, e.g. an idle link) no longer divides by zero.
    """
    if max_val <= 0:
        return 7
    return min(7, 8 - math.ceil(value * 8 / max_val))


def _rate_color(rate):
    """Return the COLOR_MAP colour of the highest threshold that *rate* reaches."""
    reached = [key for key in COLOR_MAP if key <= rate]
    # A rate below every threshold (only possible for negative input) falls back
    # to the lowest band instead of raising on max([]).
    return COLOR_MAP[max(reached) if reached else min(COLOR_MAP)]


def _human_rate(rate):
    """Format *rate* as an integer plus a HUMAN_RATES suffix, e.g. 2048 -> '2KB'."""
    unit = 0
    # Promote at exactly 1024 too ('1KB', not '1024B'), and never step past the
    # last available suffix.
    while rate >= 1024 and unit < len(HUMAN_RATES) - 1:
        rate /= 1024
        unit += 1
    return '%s%s' % (int(rate), HUMAN_RATES[unit])


def _draw_graph(data):
    """Plot one pixel per sample (column = sample index), coloured by rate."""
    if not data:
        return
    max_val = max(data)
    display.drawFill(0)
    for x, rate in enumerate(data):
        display.drawPixel(x, _graph_row(rate, max_val), _rate_color(rate))
    display.flush()


def _draw_rate(data, color):
    """Show the newest sample of *data* as human-readable text in *color*."""
    if not data:
        return
    display.drawFill(0)
    display.drawText(0, -1, _human_rate(data[-1]), color, "7x5")
    display.flush()


def update_display():
    """Redraw the LED matrix if needed and rotate through the display modes.

    Runs from the virtualtimers task registered in main(). When
    ``display_dirty`` is set, the current ``display_mode`` is rendered from
    ``measurements`` and the flag is cleared. Every 25th call advances
    ``display_mode`` (wrapping at DISPLAY_IDX_COUNT) and forces a redraw.

    Returns:
        100, handed back to virtualtimers (presumably the next interval in ms;
        TODO confirm against the virtualtimers API).
    """
    global display_mode, display_count, display_dirty

    if display_dirty:
        if display_mode == DISPLAY_GRAPH_DOWN:
            _draw_graph(measurements['down'])
        elif display_mode == DISPLAY_GRAPH_UP:
            _draw_graph(measurements['up'])
        elif display_mode == DISPLAY_RATE_DOWN:
            _draw_rate(measurements['down'], 0x00FF00)  # green
        elif display_mode == DISPLAY_RATE_UP:
            _draw_rate(measurements['up'], 0xFF0000)  # red
        display_dirty = False

    max_count = 25  # 2.5 seconds per mode
    display_count = (display_count + 1) % max_count
    if display_count == 0:
        display_mode = (display_mode + 1) % DISPLAY_IDX_COUNT
        display_dirty = True

    return 100


def main():
    """Connect to WiFi and MQTT, initialise shared state, then start the timers.

    Returns early, without starting anything, when WiFi is not connected
    after ``uinterface.connect_wifi()``.
    """
    global client, display_count, display_dirty, display_mode, measurements

    print(">> Main")
    uinterface.connect_wifi()

    if not wifi.status():
        print("<< Failed to connect to WiFi")
        return

    # Create the shared state before anything that can use it.
    # client.subscribe() may already dispatch a message to callback(), and the
    # timers call update_display()/callback() as soon as they start. Both read
    # and write these globals, so initialising them later risks a NameError.
    print(" - Defaults")
    measurements = {
        'up': [],
        'down': [],
        'changed': False
    }
    display_count = 0
    display_dirty = True
    display_mode = 0

    print(" - MQTT")
    server = b"mqtt.space.hackalot.nl"
    client = MQTTClient(b"mqttBadge-CZ19", server)
    client.set_callback(callback)
    client.connect()
    client.subscribe("collectd/OpenWrt/interface-eth0/if_octets")

    print(" - Timer")
    virtualtimers.new(100, update_display, True)
    virtualtimers.new(100, process, True)
    virtualtimers.begin(100)

    print(" + Bootup done")

    
# Start the app as soon as this module is loaded.
main()

