import bhi160
import display
import utime
from urandom import randint


class Vector:
    """Minimal 2-D vector with arithmetic and magnitude-based ordering.

    ``==`` / ``!=`` compare components, while ``<``, ``<=``, ``>``, ``>=``
    compare magnitudes (``abs``).  Addition and subtraction are duck-typed:
    any object with ``x`` and ``y`` attributes (e.g. a Target) is accepted.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return 'Vector(x={}, y={})'.format(self.x, self.y)

    def __radd__(self, other):
        # ``sum()`` starts from 0; treat it as the additive identity so
        # ``sum(list_of_vectors)`` works.
        if other == 0:
            return Vector(self.x, self.y)
        return self.__add__(other)

    def __add__(self, other):
        return Vector(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        return Vector(self.x - other.x, self.y - other.y)

    def __rsub__(self, other):
        return Vector(other.x - self.x, other.y - self.y)

    def __neg__(self):
        return Vector(-self.x, -self.y)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __mul__(self, other):
        """Dot product with a Vector, or component-wise scaling by a number."""
        if isinstance(other, Vector):
            return self.x * other.x + self.y * other.y
        if isinstance(other, (int, float)):
            return Vector(self.x * other, self.y * other)
        # Let Python raise TypeError instead of silently yielding None.
        return NotImplemented

    def __truediv__(self, other):
        return Vector(self.x / other, self.y / other)

    def __eq__(self, other):
        try:
            return self.x == other.x and self.y == other.y
        except AttributeError:
            # Not vector-like (e.g. None): defer to Python's default handling
            # instead of crashing.
            return NotImplemented

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __lt__(self, other):
        return abs(self) < abs(other)

    def __le__(self, other):
        return abs(self) <= abs(other)

    def __ge__(self, other):
        return abs(self) >= abs(other)

    def __gt__(self, other):
        return abs(self) > abs(other)

    def __abs__(self):
        """Euclidean length of the vector."""
        return (self.x ** 2 + self.y ** 2) ** 0.5

    def swap_xy(self):
        """Return a new Vector with the x and y components exchanged."""
        return Vector(self.y, self.x)


class Line:
    """Line segment between two Vectors.

    The endpoints are ordered by magnitude (``Vector.__lt__`` compares
    ``abs``), so ``start`` is the endpoint closer to the origin.  This does
    NOT imply ``start.x <= end.x``.
    """

    def __init__(self, a: Vector, b: Vector):
        _sorted = sorted([a, b])
        self.start = _sorted[0]
        self.end = _sorted[1]

    def __getitem__(self, index):
        """Return ``start`` for 0/'a'/'start' and ``end`` for 1/'b'/'end'.

        Raises:
            IndexError: for any other index.  This also makes ``iter(line)``
                stop after the two endpoints instead of looping forever.
        """
        if index in [0, 'a', 'start']:
            return self.start
        if index in [1, 'b', 'end']:
            return self.end
        raise IndexError('Line index out of range')

    def __contains__(self, other):
        """Return True if the Vector ``other`` lies on this segment.

        The on-line test uses exact equality against ``y = slope * x +
        offset``, so it is reliable for points computed from this same line
        (e.g. by ``intersection_of_orthogonal_and_point``).  Non-Vector
        operands are never contained.
        """
        if not isinstance(other, Vector):
            return False
        diff_x = self.end.x - self.start.x
        diff_y = self.end.y - self.start.y
        if diff_x == 0:
            if diff_y == 0:
                # Degenerate segment (a single point).  Swapping axes would
                # give another degenerate segment and recurse forever.
                return other.x == self.start.x and other.y == self.start.y
            # Vertical segment: the slope is undefined, so test the mirrored
            # (non-vertical) segment instead.
            return other.swap_xy() in self.swap_xy()
        slope = diff_y / diff_x
        offset = self.start.y - slope * self.start.x
        if other.y != slope * other.x + offset:
            return False
        # Endpoints are ordered by magnitude, not by x, so bound x explicitly.
        low = min(self.start.x, self.end.x)
        high = max(self.start.x, self.end.x)
        return low <= other.x <= high

    def __repr__(self):
        return 'Line(a={}, b={})'.format(self.start, self.end)

    def swap_xy(self):
        """Return this segment mirrored across the line y == x."""
        return Line(self.start.swap_xy(), self.end.swap_xy())


class Target:
    """Target the ball has to hit, drawn as a circle by the main loop.

    ``x``/``y`` are the centre coordinates and ``radius`` its radius.
    """

    def __init__(self, x, y, radius):
        self.x = x
        self.y = y
        self.radius = radius

    def __contains__(self, other):
        """Return True if Vector ``other`` is strictly inside the hit-box.

        Note: the hit-box is the axis-aligned square of side ``2 * radius``
        around the centre, not the circle itself.  Points near the square's
        corners therefore count as inside.

        Raises:
            TypeError: if ``other`` is not a Vector.
        """
        if not isinstance(other, Vector):
            raise TypeError('it is only possible to check whether a Vector is '
                            'within Target')
        return (
            (self.x - self.radius < other.x < self.x + self.radius)
            and
            (self.y - self.radius < other.y < self.y + self.radius)
        )

    def __repr__(self):
        return 'Target(x={}, y={}, radius={})'.format(
                self.x, self.y, self.radius)


def intersection_of_orthogonal_and_point(line: Line, point: Vector):
    """Return the foot of the perpendicular dropped from ``point`` onto ``line``.

    ``line`` is treated as the infinite line through its two endpoints, so
    the result may lie outside the segment; use ``result in line`` to check.
    ``point`` only needs ``x`` and ``y`` attributes (the main loop passes a
    Target).  If both endpoints coincide, that endpoint is returned.
    """
    diff_x = line.end.x - line.start.x
    diff_y = line.end.y - line.start.y
    if diff_x == 0 and diff_y == 0:
        # Degenerate line (a single point): the closest point is the point.
        x = line.start.x
        y = line.start.y
    elif diff_x == 0:
        # Vertical line: the perpendicular through ``point`` is horizontal.
        x = line.start.x
        y = point.y
    elif diff_y == 0:
        # Horizontal line: the perpendicular through ``point`` is vertical.
        x = point.x
        y = line.start.y
    else:
        # Perpendicular through ``point``: y = m_o * x + b_o.
        m_o = - diff_x / diff_y
        b_o = point.y - m_o * point.x

        # The line itself, y = m_g * x + b_g, anchored at ``line.start``.
        # Computing y from this equation (rather than a vector projection)
        # keeps the result exactly on it, which matters for callers that
        # test containment with exact float equality.
        m_g = diff_y / diff_x
        b_g = line.start.y - m_g * line.start.x

        x = (b_o - b_g) / (m_g - m_o)
        y = m_g * x + b_g

    return Vector(x, y)


class Timing:
    """Countdown of ``timeout`` seconds starting at millisecond stamp ``start``.

    When ``start`` is omitted, the countdown starts at the current
    ``utime.time_ms()``.
    """

    def __init__(self, timeout=15, start=None):
        self._timeout = timeout
        if start is None:
            start = utime.time_ms()
        self._timestamp = start

    def malus(self, malus=0.5):
        """Shorten the countdown by ``malus`` seconds."""
        self._timeout = self._timeout - malus

    def remaining_time(self):
        """Return the seconds left as a float, clamped to 0.0 once expired."""
        elapsed_s = float(utime.time_ms() - self._timestamp) / 1000
        left = float(self._timeout) - elapsed_s
        return 0.0 if left < 0 else left

    def draw(self, disp, posx=115, posy=15):
        """Print the remaining time with two decimals at (posx, posy)."""
        label = '{:.2f}'.format(self.remaining_time())
        disp.print(label, posx=posx, posy=posy, font=display.FONT8)


class Credits:
    """Player score, accumulated from the time left when a target is hit."""

    def __init__(self):
        self.credit = 0

    def add_score(self, remaining_time):
        """Add ``remaining_time`` rounded to whole points."""
        points = round(remaining_time)
        self.credit = self.credit + points

    def draw(self, disp, posx=5, posy=5):
        """Print the current score at (posx, posy)."""
        text = 'Points: {: 4d}'.format(self.credit)
        disp.print(text, posx=posx, posy=posy, font=display.FONT8)


class Level:
    """Current level, derived from the score (one level per 20 credits)."""

    def __init__(self):
        self.level = 1

    def update_level(self, credits):
        """Raise the level if ``credits.credit`` warrants it.

        Returns True only when the level went up; it never goes down.
        """
        reached = 1 + credits.credit // 20
        if reached <= self.level:
            return False
        self.level = reached
        return True

    def draw(self, disp, posx=115, posy=5):
        """Print the current level at (posx, posy)."""
        text = 'lvl: {: 4d}'.format(self.level)
        disp.print(text, posx=posx, posy=posy, font=display.FONT8)


class Stats:
    """Difficulty parameters of the game, tightened by ``level_up``.

    Args:
        hits: number of targets hit so far.
        level: stored as ``self.level``; not otherwise used by this class.
        circle_size: target radius in pixels (shrinks down to 4).
        hits_per_level: stored as ``self.hits_per_level``; not otherwise used.
        accel_factor: scale applied to accelerometer samples when the main
            loop integrates the ball velocity.
        time_to_next_hit: seconds allowed per target (shrinks down to 5).
        decel_factor: factor the ball velocity is multiplied by on a wall
            bounce; capped at ``MAX_DECEL_FACTOR``.
        malus: time penalty in seconds, grown by ``level_up``.
    """

    # Above 1.0 a wall bounce would speed the ball up instead of damping it.
    MAX_DECEL_FACTOR = 1.0

    def __init__(self, hits=0, level=0, circle_size=10, hits_per_level=5,
                 accel_factor=15, time_to_next_hit=15, decel_factor=0.66,
                 malus=0.5):
        self.hits = hits
        self.level = level
        self.hits_per_level = hits_per_level
        self.circle_size = circle_size
        self.accel_factor = accel_factor
        self.decel_factor = decel_factor
        self.time_to_next_hit = time_to_next_hit
        self.malus = malus

    def __repr__(self):
        return (('Stats(hits={}, circle_size={}, '
                 'accel_factor={}, time_to_next_hit={})')
                .format(self.hits, self.circle_size,
                        self.accel_factor,
                        self.time_to_next_hit))

    def level_up(self):
        """Make the game harder by one step and log the new settings."""
        self.accel_factor = round(self.accel_factor * 1.2)
        # Livelier bounces each level, but never energy-gaining ones.
        self.decel_factor = min(self.decel_factor * 1.1,
                                self.MAX_DECEL_FACTOR)
        if self.time_to_next_hit > 5:
            self.time_to_next_hit -= 1
        # Keep the penalty below the time available for a single target.
        if self.malus < self.time_to_next_hit:
            self.malus += 0.5
        if self.circle_size > 4:
            self.circle_size -= 1

        print('new accel factor:', self.accel_factor)
        print('new decel factor:', self.decel_factor)
        print('new malus:', self.malus)

    def new_hit(self):
        """Count one more target hit."""
        self.hits += 1


stats = Stats()

DISPLAY = Vector(x=160, y=80)
DELAY = 0.01
# Number of recent ball positions kept for the fading trail.
TRACE_LENGTH = 20


def spawn_target():
    """Return a Target at a random position that is fully on screen.

    The radius is the current ``stats.circle_size``.  The centre is kept at
    least one radius away from every edge, so the drawn circle is never
    partially off the display.
    """
    radius = stats.circle_size
    return Target(x=randint(radius, DISPLAY.x - radius),
                  y=randint(radius, DISPLAY.y - radius),
                  radius=radius)


# circle
target = spawn_target()
credits = Credits()
timer = Timing(timeout=stats.time_to_next_hit)
level = Level()

bhi = bhi160.BHI160Accelerometer()

ball_origin = Vector(x=DISPLAY.x // 2, y=DISPLAY.y // 2)
ball_trace = [ball_origin - Vector(1, 0)]
ball_velocity = Vector(x=0, y=0)

# setup display
with display.open() as d:
    if hasattr(d, 'backlight'):
        d.backlight(1000)


timestamp = utime.time_ms()
while True:
    # Poll until the accelerometer returns at least one sample; only the
    # first sample of the returned batch is used.
    while True:
        current_sample = bhi.read()
        if len(current_sample) > 0:
            current_sample = current_sample[0]
            break
        utime.sleep(DELAY)
    # Seconds since the previous frame, used as the integration time step.
    cumm_sleep_time = float(utime.time_ms() - timestamp) / 1000
    timestamp = utime.time_ms()

    ball_velocity = Vector(
        x=ball_velocity.x
        + stats.accel_factor * cumm_sleep_time * current_sample.x,
        y=ball_velocity.y
        - stats.accel_factor * cumm_sleep_time * current_sample.y
    )

    ball_origin = Vector(
        x=round(ball_origin.x + cumm_sleep_time * ball_velocity.x),
        y=round(ball_origin.y + cumm_sleep_time * ball_velocity.y)
    )

    # Segment travelled this frame and the point on it closest to the target
    # centre; lets fast balls register a hit when they fly through a target.
    last_segment = Line(ball_origin, ball_trace[-1])
    intersect = intersection_of_orthogonal_and_point(last_segment, target)

    if (ball_origin in target
            or (last_segment.start != last_segment.end
                and intersect in last_segment
                and abs(intersect - target) < target.radius)):
        target = spawn_target()
        credits.add_score(timer.remaining_time())
        if level.update_level(credits):
            stats.level_up()
        timer = Timing(timeout=stats.time_to_next_hit)
        stats.new_hit()

    # Reflect the ball off the walls.  Each bounce damps the velocity and
    # costs the current (level-dependent) time penalty.
    if ball_origin.x > DISPLAY.x:
        ball_velocity.x *= - stats.decel_factor
        ball_velocity.y *= stats.decel_factor
        remainder = ball_origin.x - DISPLAY.x
        ball_origin.x = DISPLAY.x - remainder
        timer.malus(stats.malus)

    if ball_origin.x < 0:
        ball_velocity.x *= - stats.decel_factor
        ball_velocity.y *= stats.decel_factor
        ball_origin.x *= -1
        timer.malus(stats.malus)

    if ball_origin.y > DISPLAY.y:
        ball_velocity.x *= stats.decel_factor
        ball_velocity.y *= - stats.decel_factor
        remainder = ball_origin.y - DISPLAY.y
        ball_origin.y = DISPLAY.y - remainder
        timer.malus(stats.malus)

    if ball_origin.y < 0:
        ball_velocity.x *= stats.decel_factor
        ball_velocity.y *= - stats.decel_factor
        ball_origin.y *= -1
        timer.malus(stats.malus)

    if len(ball_trace) >= TRACE_LENGTH:
        ball_trace = ball_trace[1:] + [ball_origin]
    else:
        ball_trace.append(ball_origin)
    with display.open() as disp:
        disp.clear()

        credits.draw(disp)
        timer.draw(disp)
        level.draw(disp)

        # draw the circle (red ring: red disc outline minus a black inner one)
        disp.circ(target.x, target.y, target.radius, col=[255, 0, 0])
        disp.circ(target.x, target.y, target.radius - 2, col=[0, 0, 0])

        # draw the moving dot as a trail that brightens towards the newest
        # position
        trace_len = len(ball_trace)
        for i, (p, q) in enumerate(zip(ball_trace, ball_trace[1:])):
            col = round(255 / trace_len * i)
            disp.line(p.x, p.y, q.x, q.y, col=[col, col, col])

        disp.update()