
import java.util.ArrayList;
import java.util.HashMap;

import logist.agent.Agent;
import logist.plan.Action;
import logist.plan.Action.Move;
import logist.plan.Action.Pickup;
import logist.simulation.Vehicle;
import logist.task.Task;
import logist.task.TaskDistribution;
import logist.topology.Topology;
import logist.topology.Topology.City;

/**
 * Computes a policy for a single vehicle with value iteration and exposes it
 * through {@link #optimalAction(Vehicle, Task)}.
 *
 * A state is the pair (current city, delivery city of the task offered there),
 * where the second component is {@code null} when no task is offered. An action
 * is either "take the offered task and go to its delivery city" or "decline and
 * move to a neighbouring city". All lookup tables are filled and the policy is
 * computed inside the constructor.
 */
class QValueIteration{

	// Starting value for the running maximum over Q-values. Kept as a large
	// finite number (not -infinity) so that multiplying it by a zero
	// probability can never produce NaN.
	private static final double MIN_Q_VALUE = -1000000000000000.0;

	// store the vehicle
	public Vehicle V;
	public TaskDistribution td;
	public Topology T;
	// some relevant values
	public int numCities;
	public int numMyStates;
	public int numMyActions;
	
	// We construct some maps so that values only need to be computed
	// once and can be looked up quickly during the iterations
	
	// a map from city id to city object; index numCities holds null ("no task")
	public ArrayList<City> cityMap;
	// a map from a state to a list of possible actions
	public HashMap<myState, ArrayList<myAction>> actionMap;
	// a map from a state-action pair to a list of possible next states
	public HashMap<myState, HashMap<myAction, ArrayList<myState>>> nextStatesMap;
	// a map from a state and action pair to a reward
	public HashMap<myState, HashMap<myAction, Double>> rewardTable;
	// a map from current state, action, and potential future state to the probability of that transition
	public HashMap<myState, HashMap<myAction, HashMap<myState, Double>>> transitionProbabilityTable;
	// a map from state to values for optimization
	public HashMap<myState, Double> stateValues;
	// a map from state to optimal action
	public HashMap<myState, myAction> policy;
	
	// ************************************************************
	// HERE WE DEFINE THE CONSTRUCTOR
	// ************************************************************
	/**
	 * Builds all lookup tables and runs value iteration.
	 *
	 * @param topology the city graph
	 * @param td       source of task probabilities and rewards
	 * @param agent    the agent; its first vehicle is used for the cost per km,
	 *                 and the optional properties "discount-factor" (default 0.95)
	 *                 and "threshold" (default 1e-14) are read from it
	 */
	public QValueIteration(Topology topology, TaskDistribution td, Agent agent) {
		this.V = agent.vehicles().get(0);
		this.td = td;
		this.T = topology;
		
		// construct helper variables for efficiency
		this.numCities = topology.size();
		this.numMyActions = this.numCities*2;
		this.numMyStates = this.numCities*(this.numCities+1);
		
		// Map from city id to city object. The list is pre-filled with
		// numCities+1 nulls and populated with set(), so the result does not
		// depend on the order in which the topology yields its cities. The
		// extra trailing null is the "no task" entry used by myState(int).
		// NOTE(review): assumes city ids are exactly 0..numCities-1 -- confirm.
		this.cityMap = new ArrayList<City>(this.numCities+1);
		for(int i=0; i<=this.numCities; i++) {
			this.cityMap.add(null);
		}
		for(City c: topology) {
			this.cityMap.set(c.id, c);
		}
		
		// initialize map from a state to a list of possible actions
		this.actionMap = new HashMap<myState, ArrayList<myAction>>(this.numMyStates, 1.0f);
		
		// initialize map from a state-action pair to a list of possible next states
		this.nextStatesMap = new HashMap<myState, HashMap<myAction, ArrayList<myState>>>(this.numMyStates, 1.0f);
		
		// build reward table for rapid lookup
		this.buildRewardTable();
		
		// build transition probability table
		this.buildTransitionProbabilityTable();
		
		// every state starts with value 0; the policy is filled by value iteration
		this.stateValues = new HashMap<myState, Double>(this.numMyStates, 1.0f);
		for(int i=0; i<this.numMyStates; i++) {
			this.stateValues.put(new myState(i), 0.0);
		}
		this.policy = new HashMap<myState, myAction>(this.numMyStates, 1.0f);
		
		// Reads the discount factor from the agent's properties.
		// If the property is not present it defaults to 0.95
		Double discount = agent.readProperty("discount-factor", Double.class,
				0.95);

		// Reads the iteration threshold from the agent's properties.
		// If the property is not present it defaults to 0.00000000000001
		Double threshold = agent.readProperty("threshold", Double.class,
				0.00000000000001);
		
		// run value iteration algorithm to determine optimal policy
		this.valueIteration(discount, threshold);
	}

	// ************************************************************
	// HERE WE DEFINE OUR STATE OBJECT
	// ************************************************************
	// Note: We define the state nested inside the larger class so
	// we have access to some macro variables like numCities, which
	// we use to compute hashcodes
	/**
	 * A state: the vehicle's current city plus the delivery city of the task
	 * offered there ({@code null} when there is no task). Instances are used as
	 * hash-map keys, so the fields are immutable.
	 */
	private class myState {
		public final City currentCity;
		public final City taskCity;
		
		public myState(City current, City next) {
			this.currentCity = current;
			this.taskCity = next;
		}
		
		/**
		 * Rebuilds a state from the integer produced by {@link #hashCode()}:
		 * {@code id = currentCity.id + taskIndex * numCities}.
		 */
		public myState(int id) {
			int currentID = id%numCities;
			int taskID = id/numCities;
			this.currentCity = cityMap.get(currentID);
			// index numCities of cityMap holds null, i.e. "no task"
			this.taskCity = cityMap.get(taskID);
		}
		
		// id of the task's delivery city, or numCities when there is no task
		private int taskIndex() {
			return (this.taskCity == null) ? numCities : this.taskCity.id;
		}
		
		// content-based hash; also serves as the state's integer id
		@Override
		public int hashCode() {
			return this.currentCity.id + this.taskIndex()*numCities;
		}

		// two states are equal when both city components carry the same ids
		@Override
		public boolean equals(Object o) {
			if (o == this) {
				return true;
			}
			if (!(o instanceof myState other)) {
				return false;
			}
			return this.currentCity.id == other.currentCity.id
					&& this.taskIndex() == other.taskIndex();
		}
	}
	
	// ************************************************************
	// HERE WE DEFINE OUR ACTION OBJECT
	// ************************************************************
	/**
	 * An action: whether the offered task is taken, and the city travelled to.
	 * When the task is taken the destination is the task's delivery city;
	 * otherwise it is a city chosen by the agent. Instances are used as
	 * hash-map keys, so the fields are immutable.
	 */
	private class myAction {
		public final boolean takeTask;
		public final City nextCity;
		
		public myAction(boolean take, City next) {
			this.takeTask = take;
			this.nextCity = next;
		}
		
		// content-based hash: "take" actions are offset by numCities so they
		// do not collide with "move" actions to the same city
		@Override
		public int hashCode() {
			return this.takeTask ? this.nextCity.id + numCities : this.nextCity.id;
		}

		// two actions are equal when the flag and the destination id match
		@Override
		public boolean equals(Object o) {
			if (o == this) {
				return true;
			}
			if (!(o instanceof myAction other)) {
				return false;
			}
			return this.takeTask == other.takeTask
					&& this.nextCity.id == other.nextCity.id;
		}
	}
	
	// ************************************************************
	// HERE WE WRITE OUR CLASS METHODS
	// ************************************************************
	
	/**
	 * Returns the actions available in a state, computing and caching the list
	 * in {@code actionMap} on first request.
	 *
	 * @param s the state
	 * @return the "take task" action first (only when a task is offered),
	 *         followed by one "move" action per neighbour of the current city
	 */
	private ArrayList<myAction> getActions(myState s){
		return this.actionMap.computeIfAbsent(s, state -> {
			ArrayList<myAction> actions = new ArrayList<myAction>();
			// if there is a task available, there exists the action of
			// taking the task and moving to the target city
			if(state.taskCity!=null) {
				actions.add(new myAction(true, state.taskCity));
			}
			// declining the task: only neighbouring cities are considered,
			// since any longer trip has to pass through a neighbour anyway
			for(City c: state.currentCity.neighbors()) {
				actions.add(new myAction(false, c));
			}
			return actions;
		});
	}
	
	/**
	 * Returns the states that can follow a state-action pair, computing and
	 * caching the list in {@code nextStatesMap} on first request.
	 *
	 * @param s the current state
	 * @param a the action taken
	 * @return the "no task" state at the action's destination, followed by one
	 *         state per city of the topology as possible task delivery city
	 */
	private ArrayList<myState> getNextStates(myState s, myAction a){
		HashMap<myAction, ArrayList<myState>> perAction = this.nextStatesMap.computeIfAbsent(
				s, key -> new HashMap<myAction, ArrayList<myState>>(this.numMyActions, 1.0f));
		return perAction.computeIfAbsent(a, action -> {
			ArrayList<myState> states = new ArrayList<myState>();
			// the destination of the action becomes the new current city
			City arrival = action.nextCity;
			// there could be no task, giving target city of null
			states.add(new myState(arrival, null));
			// there could be a task with potentially any target city
			for(City c: this.T.cities()) {
				states.add(new myState(arrival, c));
			}
			return states;
		});
	}
	
	/**
	 * Fills {@code rewardTable}: for every state and available action the entry
	 * is the task reward (0 when the task is declined) minus the travel cost,
	 * where cost = vehicle cost per km * distance to the action's destination.
	 */
	private void buildRewardTable() {
		this.rewardTable = new HashMap<myState, HashMap<myAction, Double>>(this.numMyStates, 1.0f);
		
		for(int i=0; i<this.numMyStates; i++) {
			myState s = new myState(i);
			HashMap<myAction, Double> rewardsForState = new HashMap<myAction, Double>(this.numMyActions, 1.0f);
			City source = s.currentCity;
			for(myAction action: this.getActions(s)) {
				City target = action.nextCity;
				double cost = V.costPerKm()*source.distanceTo(target);
				double reward = 0.0;
				if(action.takeTask) {
					reward = this.td.reward(source, target);
				}
				rewardsForState.put(action, reward-cost);
			}
			this.rewardTable.put(s, rewardsForState);
		}
	}
	
	/**
	 * Fills {@code transitionProbabilityTable}: for every state, action and
	 * possible next state the entry is
	 * {@code td.probability(action destination, next state's task city)}.
	 * For the "no task" next state the second argument is {@code null}
	 * (presumably the task distribution then reports the probability that no
	 * task is offered -- verify against the TaskDistribution documentation).
	 */
	private void buildTransitionProbabilityTable() {
		this.transitionProbabilityTable = new HashMap<myState, HashMap<myAction, HashMap<myState, Double>>>(this.numMyStates, 1.0f);
		for(int i=0; i<this.numMyStates; i++) {
			myState s = new myState(i);
			HashMap<myAction, HashMap<myState, Double>> perAction = new HashMap<myAction, HashMap<myState, Double>>(this.numMyActions, 1.0f);
			for(myAction a: this.getActions(s)) {
				HashMap<myState, Double> perNextState = new HashMap<myState, Double>(this.numMyStates, 1.0f);
				for(myState nextS: this.getNextStates(s, a)) {
					perNextState.put(nextS, this.td.probability(a.nextCity, nextS.taskCity));
				}
				perAction.put(a, perNextState);
			}
			this.transitionProbabilityTable.put(s, perAction);
		}
	}
	
	/**
	 * Runs value iteration until the sum of squared value changes over one
	 * sweep is no larger than {@code threshold}, filling {@code stateValues}
	 * and {@code policy}.
	 *
	 * Within a sweep, updated values are written to and read from the same
	 * working map, so later states in the sweep already see the new values of
	 * earlier states.
	 *
	 * @param discount  discount factor applied to future state values
	 * @param threshold convergence bound on the summed squared change
	 */
	private void valueIteration(double discount, double threshold) {
		// materialise the state objects once instead of once per sweep
		ArrayList<myState> allStates = new ArrayList<myState>(this.numMyStates);
		for(int i=0; i<this.numMyStates; i++) {
			allStates.add(new myState(i));
		}
		
		// working copy of the state values, updated in place during a sweep
		HashMap<myState, Double> workingValues = new HashMap<myState, Double>(this.numMyStates, 1.0f);
		for(myState s: allStates) {
			workingValues.put(s, this.stateValues.get(s));
		}
		
		double delta;
		do {
			// compute value for each state
			for(myState s: allStates) {
				double maxQvalue = MIN_Q_VALUE;
				HashMap<myAction, Double> rewards = this.rewardTable.get(s);
				HashMap<myAction, HashMap<myState, Double>> transitions = this.transitionProbabilityTable.get(s);
				for(myAction a: this.getActions(s)) {
					// Q(s,a) = R(s,a) + discount * sum over s' of P(s'|s,a) * V(s')
					double qValue = rewards.get(a);
					HashMap<myState, Double> probabilities = transitions.get(a);
					for(myState nextS: this.getNextStates(s, a)) {
						double transProb = probabilities.get(nextS);
						qValue+=discount*transProb*workingValues.get(nextS);
					}
					// strictly better Q-value: remember it and the action
					if(qValue>maxQvalue) {
						maxQvalue = qValue;
						this.policy.put(s, a);
					}
				}
				// set the new state value
				workingValues.put(s, maxQvalue);
			}
			// sum of squared changes over this sweep; publish the new values
			delta = 0.0;
			for(myState s: allStates) {
				double updated = workingValues.get(s);
				double vsDelta = this.stateValues.get(s) - updated;
				this.stateValues.put(s, updated);
				delta+= vsDelta*vsDelta;
			}
		} while(delta>threshold);
	}
	
	/**
	 * Translates the computed policy into a Logist action for the vehicle's
	 * current situation.
	 *
	 * @param vehicle       the vehicle; only its current city is read
	 * @param availableTask the task offered in that city, or {@code null}
	 * @return a {@code Pickup} of the task when the policy takes it, otherwise
	 *         a {@code Move} to the city chosen by the policy
	 * @throws IllegalStateException if the policy has no entry for the state
	 *         (happens when no action was available for it during value
	 *         iteration)
	 */
	public Action optimalAction(Vehicle vehicle, Task availableTask) {
		City currentCity = vehicle.getCurrentCity();
		City taskCity = (availableTask==null) ? null : availableTask.deliveryCity;
		
		myAction best = this.policy.get(new myState(currentCity, taskCity));
		if(best==null) {
			throw new IllegalStateException("No policy entry for current city "
					+ currentCity + " with task destination " + taskCity);
		}
		if(best.takeTask) {
			return new Pickup(availableTask);
		}
		return new Move(best.nextCity);
	}
	
}