from .abstract_environment import AbstractEnvironment
from typing import Tuple
import random

# Type aliases used in the environment's signatures. States, actions and rewards are all plain integers:
# states are numbered 1, ..., 11 (11 is the goal state G), actions 1, ..., 4, and rewards are 0 or 1.
State = int
Action = int
Reward = int


class SequentialDecisionMaking(AbstractEnvironment[State, Action, Reward]):
    """
    This environment implements the game shown below.
    In the states 1-7, there is one action that results in not changing state, one action leads to the next state and
    two actions go to (two different) 'trap' states: 8-10. In these trap states, there are three actions each leading
    to one of the three trap states and one action leads to state 1. G is the final state, and is the terminal state
    where reward +1 is received. Internally, state G is 11. So all states can be uniquely identified by 1, ..., 11.
    Actions are numbered 1, 2, 3, 4.

    ┌──►1──►2──►3──►4──►5──►6──►7──►G
    │   └───┴───┴───┴───┴───┴┬──┘
    │          ┌──────┐      │
    └──────────┤8 9 10│◄─────┘
               └──────┘
    Upon initialisation, in each state the actions are randomly assigned to the transitions. The assignment is
    determined by `seed`, which is used to reseed Python's global `random` module.
    """

    def __init__(self, seed: int = 1):
        self._end = False
        self._current_state = 1
        self.seed = seed
        # NOTE: this reseeds the *global* `random` module, so it also affects any other code that uses `random`.
        random.seed(seed)
        self.trap_states = [8, 9, 10]
        self.states = list(range(1, 12))
        self.goal_state = 11
        # `self.transitions[s - 1][a - 1]` is the state reached by taking action `a` in state `s` (s = 1, ..., 10).
        # The order of the `random` calls below fixes which transitions a given seed produces; do not reorder them.
        self.transitions = [
            [i, i + 1] + random.sample(self.trap_states, 2) for i in range(1, 7)
        ] + [[7, self.goal_state] + random.sample(self.trap_states, 2)] + [
            [1] + random.sample(self.trap_states, 3) for _ in self.trap_states
        ]
        for row in self.transitions:
            random.shuffle(row)

    def reset(self) -> None:
        """
        Resets the environment, but keeps the randomly assigned transitions as is.
        """
        self._current_state = 1
        self._end = False

    @staticmethod
    def available():
        """
        Returns a list of available actions. This statically returns `[1, 2, 3, 4]`.
        """
        return [1, 2, 3, 4]

    @staticmethod
    def get_num_actions() -> int:
        """
        Returns the number of available actions. This statically returns `4`.
        """
        return 4

    @staticmethod
    def get_num_states() -> int:
        """
        Returns the number of possible states. This statically returns `11`.
        """
        return 11

    @property
    def end(self) -> bool:
        """
        Attribute of the class that becomes true when the environment is in the goal state $G$, internally state `11`.
        """
        return self._end

    @staticmethod
    def _validate_action(action) -> int:
        """
        Converts `action` to `int` and checks that it is one of the four available actions.

        Raises `ValueError` if it is not one of 1, 2, 3, 4.
        """
        action = int(action)
        if action not in {1, 2, 3, 4}:
            raise ValueError("Only four available actions: 1, 2, 3, 4")
        return action

    def do_action(self, action: Action) -> Tuple[State, Reward]:
        """
        Takes the action `action` in the current state and moves the environment to the resulting state.

        Returns a tuple `(state, reward)`. Once the goal state G (`11`) has been reached, every further action
        leaves the environment in G and returns `(11, 1)`.

        Raises `ValueError` if `action` is not one of 1, 2, 3, 4.
        """
        action = self._validate_action(action)
        if self._current_state == self.goal_state:
            # G is absorbing: the agent stays there and keeps receiving reward 1.
            return self._current_state, 1
        next_state = self.transitions[self._current_state - 1][action - 1]
        self._current_state = next_state
        if next_state == self.goal_state:
            self._end = True
        return self._current_state, self.reward()

    def simulate_action(self, state: State, action: Action) -> Tuple[State, Reward]:
        """
        Simulates taking the action `action` in the state `state` without changing the environment.

        Returns the tuple `(next_state, reward)` that `do_action` would return if the environment were in `state`.
        Simulating any action from the goal state G (`11`) returns `(11, 1)`.

        Raises `ValueError` if `action` is not one of 1, 2, 3, 4, or if `state` is not one of 1, ..., 11.
        """
        action = self._validate_action(action)
        if state not in self.states:
            raise ValueError("Only eleven states: 1, ..., 11")
        if state == self.goal_state:
            # G is absorbing, mirroring `do_action`; it also has no row in `self.transitions`.
            return self.goal_state, 1

        next_state = self.transitions[state - 1][action - 1]
        # Evaluate `reward()` as if the environment were in `next_state` (so an overridden `reward` is respected),
        # and always restore the real current state afterwards.
        saved_state = self._current_state
        self._current_state = next_state
        try:
            reward = self.reward()
        finally:
            self._current_state = saved_state

        return next_state, reward

    def get_state(self) -> State:
        """
        Returns the current state.
        """
        return self._current_state

    def reward(self) -> Reward:
        """
        Returns the reward for getting to the current state. This is a reward of $1$ if the agent reaches the goal state $G$, otherwise $0$.
        """
        return 1 if self._current_state == self.goal_state else 0

    def render(self) -> None:
        """
        Prints the current game state. It puts a block around the current state by appending a combining enclosing
        square (U+20DE) to it. Not all fonts support this. If not, then the box will be printed after the state.
        """
        basegame = """\
┌──►1──►2──►3──►4──►5──►6──►7──►G
│   └───┴───┴───┴───┴───┴┬──┘
│          ┌──────┐      │
└──────────┤8 9 10│◄─────┘
           └──────┘"""
        marker = "⃞"
        # Character offsets in `basegame` of states 1, ..., 11 (11 is G). For state 10 the offset is its first digit.
        indices = [4, 8, 12, 16, 20, 24, 28, 103, 105, 107, 32]
        index = indices[self._current_state - 1]
        if self._current_state == 10:
            # State 10 is drawn with two characters; box both of them.
            print(basegame[:index + 1] + marker + basegame[index + 1] + marker + basegame[index + 2:])
        else:
            print(basegame[:index + 1] + marker + basegame[index + 1:])

