import random
import time

import display
import buttons
import sndmixer
import system

from machine import Pin
from neopixel import NeoPixel

# --- Hardware setup -------------------------------------------------------
# Five NeoPixels driven from GPIO 5; GPIO 19 is switched on before the first
# write (presumably the LED power enable -- TODO confirm against the board
# schematic).
powerPin = Pin(19, Pin.OUT)
dataPin = Pin(5, Pin.OUT)
np = NeoPixel(dataPin, 5)
powerPin.on()
# Start with every LED dark.
for i in range(5): np[i] = (0, 0, 0)
np.write()

# Acquire a synth channel for the SOUND instruction.
try:
    sndmixer.begin(3)
    synthId = sndmixer.synth()
except Exception:
    # Best-effort fallback kept from the original: if the mixer cannot be
    # (re)initialised, assume channel 1 is usable (NOTE(review): presumably
    # covers the mixer already running from an earlier launch -- confirm).
    # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit are
    # no longer swallowed here.
    synthId = 1
# Configure the channel and leave it playing silently; SOUND only changes
# the frequency and raises the volume for the duration of a note.
sndmixer.waveform(synthId, 2)
sndmixer.volume(synthId, 0)
sndmixer.freq(synthId, 440)
sndmixer.play(synthId)

# Palette. The names are palette slots rather than literal colours (BLACK is
# not 0x000000); values are presumably 24-bit 0xRRGGBB as taken by `display`.
BLACK, WHITE, GREY = 0x541388, 0xCCCCCC, 0x888888
BLUE, CYAN, RED = 0x261447, 0x2DE2E6, 0xFF3864
YELLOW, YELLOW2 = 0xFF6C11, 0x662700

# Roles of the palette entries in the UI.
COLOR_BG = COLOR_LED = BLACK
COLOR_TEXT = COLOR_A_FG = COLOR_OUTPUT = WHITE
COLOR_LABEL = COLOR_AN_BG = GREY
COLOR_A_BG = RED
COLOR_OUTPUT_BG = YELLOW
COLOR_LED_BG = CYAN

def frame(x, y, w, h, fg, bg, bg2 = None):
    """Draw a filled panel with its top-left and bottom-right corners cut off.

    The panel is a filled `fg` rectangle at (x, y) of size w x h. A 12-pixel
    triangle in `bg` is painted over the top-left corner and one in `bg2`
    (falling back to `bg` when `bg2` is None) over the bottom-right corner.
    """
    corner = 12
    right, bottom = x + w, y + h
    lower_color = bg2 if bg2 is not None else bg
    display.drawRect(x, y, w, h, True, fg)
    display.drawTri(x, y, x + corner, y, x, y + corner, bg)
    display.drawTri(right, bottom, right - corner, bottom, right, bottom - corner, lower_color)

# Bitmap font used by text(): one list per character, one integer per pixel
# row (top row first). text() tests bits 0..6 of each row and draws bit i at
# column x + 7 - i, so bit 0 is the right-most pixel of the glyph.
# Every glyph has seven rows except the space, which has a single empty row.
BITFONT = {
    "0": [0x7f, 0x41, 0x41, 0x41, 0x41, 0x41, 0x7f],
    "1": [0x38, 0x08, 0x08, 0x08, 0x08, 0x08, 0x7f],
    "2": [0x7f, 0x01, 0x01, 0x7f, 0x40, 0x40, 0x7f],
    "3": [0x7f, 0x01, 0x01, 0x3f, 0x01, 0x01, 0x7f],
    "4": [0x41, 0x41, 0x41, 0x7f, 0x01, 0x01, 0x01],
    "5": [0x7f, 0x40, 0x40, 0x7f, 0x01, 0x01, 0x7f],
    "6": [0x7f, 0x40, 0x40, 0x7f, 0x41, 0x41, 0x7f],
    "7": [0x7f, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01],
    "8": [0x7f, 0x41, 0x41, 0x7f, 0x41, 0x41, 0x7f],
    "9": [0x7f, 0x41, 0x41, 0x7f, 0x01, 0x01, 0x7f],
    ":": [0x00, 0x18, 0x18, 0x00, 0x18, 0x18, 0x00],
    # Solid 7x7 block, used for the on-screen LED dots.
    "*": [0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f],
    " ": [0],
}
def text(x, y, s, fg, bg):
    """Draw string `s` using the BITFONT glyphs, advancing 8 pixels per character.

    Each character occupies the 7x7 pixel cell whose columns are x+1 .. x+7
    and whose rows are y .. y+6. Set bits are drawn in `fg`, everything else
    in `bg`.

    Raises KeyError if `s` contains a character that is not in BITFONT.
    """
    for c in s:
        glyph = BITFONT[c]
        # Paint the whole cell in the background colour first. Glyphs with
        # fewer than seven rows (the space has only one) previously left the
        # remaining rows untouched, so a shorter number drawn over a longer
        # one kept the old digits on screen.
        display.drawRect(x + 1, y, 7, 7, True, bg)
        for row, bits in enumerate(glyph):
            for i in range(7):
                # Bit 0 is the right-most column of the glyph.
                if (bits >> i) & 1:
                    display.drawPixel(x + 7 - i, y + row, fg)
        x += 8

class LMC:
    """Little-Man-Computer style machine with 100 memory cells, drawn on the badge display.

    State:
      mem  -- 100 cells; a cell value `xyy` is decoded as opcode `x`, argument `yy`
      a    -- accumulator
      n    -- "negative" flag, set by SUB when the subtraction borrowed
      pc   -- address of the next instruction (also the editor cursor)
      op   -- the instruction fetched by the most recent step()
      out  -- characters shown in the output window
      dots -- on/off state of the ten on-screen LED dots

    Every mutator also repaints the part of the screen it affects; nothing
    here calls display.flush(), that is left to the caller.
    """

    MEM_SIZE = 100
    # Characters kept in the output window: output() trims to this many and
    # 12 glyphs of 8 px fit the 110 px wide output frame.
    OUT_WIDTH = 12
    DOT_COUNT = 10
    # Frequencies selected by the tens digit of A in the SOUND instruction.
    NOTES = (246, 261, 293, 329, 349, 392, 440, 493, 523, 587)
    # NeoPixel colours selected by the units digit of A in the REAL LED
    # instruction. None (value 8) means "pick a random colour on each call".
    LED_COLORS = (
        (0, 0, 0), (0, 0, 40), (0, 40, 0), (40, 0, 0), (0, 40, 40),
        (40, 40, 0), (40, 0, 40), (80, 80, 80), None, (40, 40, 40),
    )

    def __init__(self, code):
        """Load `code` (a sequence of cell values) at address 0 and draw the machine."""
        self.mem = [0] * self.MEM_SIZE
        # Ignore anything beyond the last cell so memory stays exactly 100 long.
        program = list(code)[:self.MEM_SIZE]
        self.mem[0:len(program)] = program
        self.a, self.n, self.pc, self.op = 0, 0, 0, 0
        self.out = [' '] * self.OUT_WIDTH
        self.dots = [0] * self.DOT_COUNT
        self.redraw()

    def step(self):
        """Fetch the instruction at pc, advance pc (wrapping at 100) and execute it.

        Opcodes (hundreds digit), nn = argument (last two digits):
          1 ADD nn   A = (A + mem[nn]) % 100
          2 SUB nn   n = 1 if A < mem[nn] else 0; A = (A - mem[nn]) mod 100
          3 STA nn   mem[nn] = A
          4 STI nn   mem[A] = mem[nn]
          5 LDA nn   A = mem[nn]
          6 JMP nn
          7 JZ nn    jump when A == 0
          8 JP nn    jump when the n flag is clear
          9 ..       I/O and extras, see _io()
        Any other value (including 0) does nothing; callers can inspect
        `self.op` afterwards to detect a 0 instruction.
        """
        self.op = self.read(self.pc)
        self.set_pc((self.pc + 1) % 100)
        # `pc` is the already-incremented program counter.
        a, pc, op, arg = self.a, self.pc, self.op // 100, self.op % 100
        if op == 1: # ADD nn
            self.set_a((a + self.read(arg)) % 100)
        elif op == 2: # SUB nn
            n = self.read(arg)
            # Flag first: set_a() picks the accumulator colour from it.
            self.set_n(int(a < n))
            self.set_a((a + 100 - n) % 100)
        elif op == 3: # STA [nn]
            self.write(arg, a)
        elif op == 4: # STI [nn]
            self.write(a, self.read(arg))
        elif op == 5: # LDA [nn]
            self.set_a(self.read(arg))
        elif op == 6: # JMP
            self.set_pc(arg)
        elif op == 7: # JZ
            self.set_pc(pc if a != 0 else arg)
        elif op == 8: # JP
            self.set_pc(pc if self.n else arg)
        elif op == 9:
            self._io(arg)

    def _io(self, arg):
        """Execute the opcode-9 instruction selected by `arg`; unknown values do nothing."""
        if arg == 1: # IN
            self.set_a(self.input(0))
        elif arg == 2: # OUT
            self.output(self.a)
        elif arg == 3: # LED ON
            self.set_dot(self.a % 10, 1)
        elif arg == 4: # LED OFF
            self.set_dot(self.a % 10, 0)
        elif arg == 5: # REAL LED: tens digit = LED, units digit = colour
            self.set_led((self.a // 10) % 5, self.a % 10)
        elif arg == 6: # KEY
            self.set_a(self._read_keys())
        elif arg == 7: # SOUND: tens digit = note, units digit = tenths of a second
            pitch, duration = (self.a // 10) % 10, self.a % 10
            sndmixer.volume(synthId, 10)
            sndmixer.freq(synthId, self.NOTES[pitch])
            time.sleep(duration / 10)
            sndmixer.volume(synthId, 0)
        elif arg > 10 and arg < 20: # SLEEP: 911..919 -> 0.1 .. 0.9 s
            time.sleep((arg % 10) / 10)
        elif arg == 99: # A = mem[A]
            self.set_a(self.read(self.a))

    def _read_keys(self):
        """Return the held buttons as a bitmask: A=1, B=2, UP=4, DOWN=8, LEFT=16, RIGHT=32."""
        # Uses the BTN_* constants like every other button access in this
        # file (the original KEY_* names appear nowhere else).
        keys = (
            buttons.BTN_A, buttons.BTN_B, buttons.BTN_UP,
            buttons.BTN_DOWN, buttons.BTN_LEFT, buttons.BTN_RIGHT,
        )
        mask = 0
        for bit, btn in enumerate(keys):
            if buttons.value(btn):
                mask |= 1 << bit
        return mask

    def read(self, addr):
        """Return the cell at `addr`, wrapping the address into 0..99."""
        # A can hold up to 999 (after IN or LDA) and is used as an address
        # by the 999 instruction; wrap like the program counter does instead
        # of raising IndexError.
        return self.mem[addr % self.MEM_SIZE]

    def write(self, addr, value):
        """Store `value` at `addr` (wrapped into 0..99) and repaint that cell."""
        addr %= self.MEM_SIZE
        self.mem[addr] = value
        self.render_cell(addr)

    def set_a(self, a):
        """Set the accumulator and repaint its panel (grey while the n flag is set, red otherwise)."""
        self.a = a
        c = COLOR_AN_BG if self.n else COLOR_A_BG
        frame(0, 0, 40, 32, c, 0)
        text(12, 13, f"{self.a:02}", COLOR_A_FG, c)

    def set_pc(self, pc):
        """Move the program counter and its outline box in the memory grid."""
        # Erase the outline around the old cell ...
        row, col = self.pc // 10, self.pc % 10
        display.drawRect(col * 28 + 31, row * 17 + 65, 28, 12, False, COLOR_BG)
        # ... and draw it around the new one.
        row, col = pc // 10, pc % 10
        display.drawRect(col * 28 + 31, row * 17 + 65, 28, 12, False, COLOR_TEXT)
        self.pc = pc

    def set_n(self, n):
        """Set the negative flag (no repaint; the next set_a() reflects it)."""
        self.n = n

    def output(self, n):
        """Append the last two digits of `n` to the output window and repaint it."""
        self.out.append(' ')
        self.out.append(str((n // 10) % 10))
        self.out.append(str(n % 10))
        self.out = self.out[-self.OUT_WIDTH:]
        self.render_output()

    def render_output(self):
        """Repaint the output panel from `self.out`."""
        frame(210, 0, 110, 32, COLOR_OUTPUT_BG, 0)
        text(213, 13, ''.join(self.out), COLOR_OUTPUT, COLOR_OUTPUT_BG)

    def set_led(self, index, value):
        """Set NeoPixel `index` (mod 5) to colour number `value` (0..9; 8 is random)."""
        color = self.LED_COLORS[value]
        if color is None:
            color = (random.randint(0, 40), random.randint(0, 40), random.randint(0, 40))
        np[index % 5] = color
        np.write()

    def set_dot(self, index, value):
        """Switch on-screen LED dot `index` (0 = right-most) on or off."""
        i = 9 - index
        self.dots[i] = value
        # Repaint only this dot. Redrawing the whole panel here, as before,
        # wiped every other lit dot off the screen.
        self._render_dot(i)

    def _render_dot(self, i):
        """Draw dot slot `i` (0 = left-most) in its current state."""
        text(58 + i * 14, 13, "*", COLOR_LED if self.dots[i] else COLOR_LED_BG, COLOR_LED_BG)

    def render_dots(self):
        """Repaint the LED panel and every lit dot from `self.dots`."""
        frame(48, 0, 155, 32, COLOR_LED_BG, 0)
        for i, lit in enumerate(self.dots):
            # Unlit dots are background-coloured, the fresh frame covers them.
            if lit:
                self._render_dot(i)

    @staticmethod
    def _digits(value):
        """Split `value` into its last three decimal digits, most significant first."""
        return [(value // 100) % 10, (value // 10) % 10, value % 10]

    def input(self, value):
        """Show a modal three-digit editor and return the entered number (0..999).

        UP/DOWN change the selected digit, LEFT/RIGHT move the selection,
        HOME zeroes all digits, A confirms and B cancels (returning the
        digits of the original `value`). The whole screen is redrawn and
        flushed before returning.
        """
        digit = 2
        n = self._digits(value)
        while not clicked(buttons.BTN_A):
            # Dialog with a drop shadow.
            frame(88, 88, 150, 80, YELLOW2, BLACK)
            frame(80, 80, 150, 80, YELLOW, BLACK, YELLOW2)
            for k in range(3):
                display.drawRect(100 + k * 40, 105, 32, 32, True, RED)
            for k in range(3):
                text(112 + k * 40, 117, f'{n[k]}', WHITE, RED)
            # Arrow markers above and below the selected digit.
            left, right, mid = 111 + digit * 40, 119 + digit * 40, 115 + digit * 40
            display.drawTri(left, 100, right, 100, mid, 95, 0)
            display.drawTri(left, 142, right, 142, mid, 147, 0)
            display.flush()
            if pressed(buttons.BTN_UP): n[digit] = (n[digit] + 1) % 10
            elif pressed(buttons.BTN_DOWN): n[digit] = (n[digit] + 9) % 10
            elif pressed(buttons.BTN_LEFT): digit = (digit + 2) % 3
            elif pressed(buttons.BTN_RIGHT): digit = (digit + 1) % 3
            elif pressed(buttons.BTN_HOME):
                n = [0, 0, 0]
            elif clicked(buttons.BTN_B):
                n = self._digits(value)
                break
        self.redraw()
        display.flush()
        return n[0] * 100 + n[1] * 10 + n[2]

    def redraw(self):
        """Repaint the complete screen from the current machine state."""
        display.drawFill(0)
        frame(0, 40, 320, 200, COLOR_BG, 0)
        for label in range(10):
            # Column header (units digit) and row header (tens digit).
            text(label * 28 + 33, 50, f"  {label}", COLOR_LABEL, COLOR_BG)
            text(8, label * 17 + 68, f"{label}0:", COLOR_LABEL, COLOR_BG)
        for addr in range(100):
            self.render_cell(addr)
        self.set_n(self.n)
        self.set_a(self.a)
        self.set_pc(self.pc)
        # Repaint the dots from state rather than clearing them, so a redraw
        # (e.g. after the IN dialog) no longer loses lit LEDs.
        self.render_dots()
        self.render_output()

    def render_cell(self, addr):
        """Draw the value of memory cell `addr`, right-aligned in three characters."""
        row, col = addr // 10, addr % 10
        text(col * 28 + 33, row * 17 + 68, f"{self.mem[addr]: >3}",COLOR_TEXT, COLOR_BG)

# Sample programs. Each number is one memory cell, loaded from address 0;
# see LMC.step() for the instruction encoding.

# Blink a NeoPixel: LDA 7 (=12), LED, SLEEP .5, SUB 8 (=2), LED, SLEEP .5, JMP 0.
# Cells 7 and 8 hold the data values 12 and 2.
BLINK = [507,905,915,208,905,915,600,12,2]
# Walk through the on-screen dots: LDA 6 (=0), then loop
# LED ON, SLEEP .3, LED OFF, ADD 7 (=1), JMP 1.
DOT = [506,903,913,904,107,601,0,1]
# IN, STA 6, IN, ADD 6, OUT: ask for two numbers and output their sum.
ADD = [901,306,901,106,902]
# Multiply two inputs by repeated addition: operands in cells 20 and 21,
# running total in cell 22, the constant 1 in cell 23.
MUL = [901,320,901,321,522,120,322,521,223,321,804,522,220,322,902,0,0,0,0,0,0,0,0,1]
# Play the note list stored from cell 20 on. Cell 10 is the note pointer
# (starts at 20), cell 11 the increment (1); a 0 note jumps to cell 8, which
# holds 0. Each note is "pitch digit, duration digit" as used by SOUND.
MUSIC = [510,999,708,907,510, 111,310,601,0,0,
         20,1,0,0,0, 0,0,0,0,0,
         52, 52, 52, 22, 32, 32, 24, 72, 72, 62, 62, 55, 0]
# The machine used by editor() and run(); constructing it draws the screen.
lmc = LMC(MUSIC)

def clicked(btn):
    """Report a completed click of `btn`.

    Returns False immediately when the button is not held. Otherwise blocks,
    polling every 0.1 s, until the button is released and then returns True.
    """
    if buttons.value(btn):
        while buttons.value(btn):
            time.sleep(0.1)
        return True
    return False

def pressed(btn):
    """Report whether `btn` is held right now.

    When it is, pause 0.2 s before returning True, so a held button is
    reported at most about five times per second to a polling caller.
    """
    if buttons.value(btn):
        time.sleep(0.2)
        return True
    return False

def editor():
    """Interactive memory editor; returns when START is clicked.

    HOME       reset: cursor to cell 0, dots and NeoPixels off, sound muted,
               output window cleared
    MENU       leave the app via system.launcher()
    arrows     move the cursor (the program counter) through the 10x10 grid
    A / B      edit the cell under the cursor with the numeric input dialog
    SELECT     execute a single instruction
    """
    while True:
        if pressed(buttons.BTN_HOME):
            lmc.set_pc(0)
            for i in range(10): lmc.set_dot(i, 0)
            for i in range(5): lmc.set_led(i, 0)
            sndmixer.volume(synthId, 0)
            # 12 characters is the window LMC.output() keeps and what fits
            # the output frame; the previous 32 ran past the screen edge.
            lmc.out = [' '] * 12
            lmc.render_output()
        elif pressed(buttons.BTN_MENU): system.launcher()
        # Cursor movement wraps around: +90 / +99 modulo 100 are -10 / -1.
        elif pressed(buttons.BTN_DOWN): lmc.set_pc((lmc.pc + 10) % 100)
        elif pressed(buttons.BTN_UP): lmc.set_pc((lmc.pc + 90) % 100)
        elif pressed(buttons.BTN_RIGHT): lmc.set_pc((lmc.pc + 1) % 100)
        elif pressed(buttons.BTN_LEFT): lmc.set_pc((lmc.pc + 99) % 100)
        elif clicked(buttons.BTN_A) or clicked(buttons.BTN_B):
            lmc.write(lmc.pc, lmc.input(lmc.read(lmc.pc)))
        elif clicked(buttons.BTN_START):
            return
        elif clicked(buttons.BTN_SELECT):
            lmc.step()
        display.flush()

def run():
    """Execute the program until it fetches a 0 instruction or START is clicked.

    Before every instruction cell 99 is overwritten with a fresh random
    number in 0..99. While SELECT is held each step is followed by a
    one-second pause.
    """
    while True:
        lmc.write(99, random.randint(0, 99))
        lmc.step()
        display.flush()
        halted = lmc.op == 0
        if halted or clicked(buttons.BTN_START):
            return
        if buttons.value(buttons.BTN_SELECT):
            time.sleep(1)

# Main loop: alternate forever between the editor and running the program,
# pausing 0.3 s after each mode returns.
while True:
    for mode in (editor, run):
        mode()
        time.sleep(0.3)