/*
 * UnreachableStatesRemover.h
 *
 *  Created on: 23. 3. 2014
 *	  Author: Tomas Pecka
 */

#ifndef UNREACHABLE_STATES_REMOVER_H_
#define UNREACHABLE_STATES_REMOVER_H_

#include <core/multipleDispatch.hpp>
#include <automaton/Automaton.h>

#include <automaton/FSM/ExtendedNFA.h>
#include <automaton/FSM/CompactNFA.h>
#include <automaton/FSM/EpsilonNFA.h>
#include <automaton/FSM/MultiInitialStateNFA.h>
#include <automaton/FSM/NFA.h>
#include <automaton/FSM/DFA.h>
#include <automaton/TA/DFTA.h>

#include "../properties/ReachableStates.h"

namespace automaton {

namespace simplify {

/**
 * Removes unreachable states from automata.
 *
 * For finite automata the states kept are those reported by
 * automaton::properties::ReachableStates::reachableStates. For DFTA the states kept are those
 * derivable bottom-up starting from nullary transitions. Each result contains only those states,
 * the transitions leaving (FSM) or built from (DFTA) them, and the final states among them.
 * The input alphabet is copied unchanged.
 */
class UnreachableStatesRemover : public std::SingleDispatch<UnreachableStatesRemover, automaton::Automaton, const automaton::AutomatonBase &> {
public:
	/**
	 * Entry point for the type-erased automaton wrapper. Defined outside this header;
	 * presumably dispatches (via SingleDispatch) to the overload for the wrapped concrete type.
	 */
	static automaton::Automaton remove( const automaton::Automaton & automaton );

	/**
	 * Removes unreachable states from an FSM (Melichar 2.29).
	 * Generic overload: each transition's value is treated as a collection of target states.
	 */
	template<class T, class SymbolType = typename automaton::SymbolTypeOfAutomaton < T >, class StateType = typename automaton::StateTypeOfAutomaton < T > >
	static T remove( const T & automaton );
	/** Removes unreachable states from a DFA (each transition has exactly one target state). */
	template < class SymbolType, class StateType >
	static automaton::DFA < SymbolType, StateType > remove( const automaton::DFA < SymbolType, StateType > & fsm );
	/** Removes unreachable states from an NFA with a set of initial states; all initial states are kept. */
	template < class SymbolType, class StateType >
	static automaton::MultiInitialStateNFA < SymbolType, StateType > remove( const automaton::MultiInitialStateNFA < SymbolType, StateType > & fsm );
	/** Removes states of a deterministic finite tree automaton that cannot be derived bottom-up from nullary transitions. */
	template < class SymbolType, class RankType, class StateType >
	static automaton::DFTA < SymbolType, RankType, StateType > remove( const automaton::DFTA < SymbolType, RankType, StateType > & dfta );
};

/**
 * Generic overload for finite automata whose transition function maps a
 * (state, input) pair to a collection of target states.
 *
 * Builds a new automaton of the same type, with the same initial state, containing:
 * - the states returned by ReachableStates::reachableStates,
 * - the whole input alphabet,
 * - every transition leaving a reachable state,
 * - the final states that are reachable.
 */
template<class T, class SymbolType, class StateType >
T UnreachableStatesRemover::remove( const T & fsm ) {
	// Step 1a: states reachable from the initial state.
	const std::set < StateType > reachable = automaton::properties::ReachableStates::reachableStates ( fsm );

	// Step 2: assemble the restricted automaton.
	T result ( fsm.getInitialState ( ) );

	for ( const auto & state : reachable )
		result.addState ( state );

	for ( const auto & symbol : fsm.getInputAlphabet ( ) )
		result.addInputSymbol ( symbol );

	for ( const auto & transition : fsm.getTransitions ( ) ) {
		const auto & source = transition.first.first;
		if ( reachable.count ( source ) == 0 )
			continue; // transitions out of unreachable states are dropped

		for ( const auto & target : transition.second )
			result.addTransition ( source, transition.first.second, target );
	}

	// Keep only the final states that survived; iterating the ordered final-state
	// set preserves ascending insertion order.
	for ( const auto & state : fsm.getFinalStates ( ) )
		if ( reachable.count ( state ) != 0 )
			result.addFinalState ( state );

	return result;
}

/**
 * DFA overload: each transition maps a (state, symbol) pair to a single target state.
 *
 * Produces a DFA with the same initial state and input alphabet, containing:
 * - the states returned by ReachableStates::reachableStates,
 * - the transitions whose source is one of those states,
 * - the final states that are among them.
 */
template < class SymbolType, class StateType >
automaton::DFA < SymbolType, StateType > UnreachableStatesRemover::remove( const automaton::DFA < SymbolType, StateType > & fsm ) {
	// Step 1a: collect the reachable states.
	const std::set < StateType > reachable = automaton::properties::ReachableStates::reachableStates ( fsm );

	// Step 2: copy the reachable part of fsm into a fresh DFA.
	automaton::DFA < SymbolType, StateType > result ( fsm.getInitialState ( ) );

	for ( const auto & state : reachable )
		result.addState ( state );

	for ( const auto & symbol : fsm.getInputAlphabet ( ) )
		result.addInputSymbol ( symbol );

	for ( const auto & transition : fsm.getTransitions ( ) ) {
		const auto & source = transition.first.first;
		if ( reachable.count ( source ) != 0 )
			result.addTransition ( source, transition.first.second, transition.second );
	}

	for ( const auto & state : fsm.getFinalStates ( ) )
		if ( reachable.count ( state ) != 0 )
			result.addFinalState ( state );

	return result;
}

/**
 * MultiInitialStateNFA overload.
 *
 * Rebuilds the automaton from:
 * - the states returned by ReachableStates::reachableStates,
 * - all of the original initial states,
 * - the whole input alphabet,
 * - every transition leaving a reachable state,
 * - the reachable final states.
 */
template < class SymbolType, class StateType >
automaton::MultiInitialStateNFA < SymbolType, StateType > UnreachableStatesRemover::remove( const automaton::MultiInitialStateNFA < SymbolType, StateType > & fsm ) {
	// Step 1a: states reachable from the initial states.
	const std::set < StateType > reachable = automaton::properties::ReachableStates::reachableStates ( fsm );

	// Step 2: rebuild the automaton over the reachable states only.
	automaton::MultiInitialStateNFA < SymbolType, StateType > result;

	for ( const auto & state : reachable )
		result.addState ( state );

	// NOTE: done after the states are inserted; presumably setInitialStates requires
	// them to be present already — verify against MultiInitialStateNFA.
	result.setInitialStates ( fsm.getInitialStates ( ) );

	for ( const auto & symbol : fsm.getInputAlphabet ( ) )
		result.addInputSymbol ( symbol );

	for ( const auto & transition : fsm.getTransitions ( ) ) {
		const auto & source = transition.first.first;
		if ( reachable.count ( source ) == 0 )
			continue;

		for ( const auto & target : transition.second )
			result.addTransition ( source, transition.first.second, target );
	}

	for ( const auto & state : fsm.getFinalStates ( ) )
		if ( reachable.count ( state ) != 0 )
			result.addFinalState ( state );

	return result;
}

/**
 * Removes unreachable states from a deterministic finite tree automaton.
 *
 * A state is reachable when some transition leading to it has all of its child
 * states reachable. Nullary transitions (empty child list) seed the computation.
 * The algorithm works as follows:
 * - Every non-nullary transition keeps a counter of child-state occurrences not
 *   yet known to be reachable.
 * - Each newly reachable state is processed exactly once. It decreases the
 *   counter of every transition that mentions it by its multiplicity there.
 * - A counter reaching zero makes that transition's target reachable.
 *
 * @param dfta the automaton to simplify
 * @return a DFTA with the same input alphabet, containing:
 *         - the reachable states,
 *         - the transitions all of whose child states are reachable,
 *         - the reachable final states.
 */
template < class SymbolType, class RankType, class StateType >
automaton::DFTA < SymbolType, RankType, StateType > UnreachableStatesRemover::remove( const automaton::DFTA < SymbolType, RankType, StateType > & dfta ) {
	automaton::DFTA < SymbolType, RankType, StateType > res;
	res.setInputAlphabet(dfta.getInputAlphabet());

	typedef std::pair < const std::pair < std::ranked_symbol < SymbolType, RankType >, std::vector < StateType > >, StateType > Transition;
	typedef std::pair < const Transition *, int > TransitionCounter;

	// One counter per non-nullary transition: (transition, child occurrences not yet reachable).
	// The capacity is reserved up front and never exceeded, so the pointers to elements
	// stored in stateOccurences below remain valid.
	std::vector < TransitionCounter > transitionsUnreachableCount;
	transitionsUnreachableCount.reserve(dfta.getTransitions().size());

	// For each state: counters of transitions whose child list contains it, mapped to
	// the number of times the state occurs in that child list.
	std::map < StateType, std::map < TransitionCounter *, int > > stateOccurences;
	std::deque < StateType > queue;
	for (const auto & transition : dfta.getTransitions()) {
		if (transition.first.second.empty()) {
			// Nullary transitions make their target reachable unconditionally.
			// Enqueue the state only the first time it is added: several nullary
			// transitions may share a target, and processing a state twice would
			// decrement the counters twice, wrongly enabling transitions that still
			// have unreachable children (or driving counters negative).
			if (res.addState(transition.second))
				queue.push_back(transition.second);
			res.addTransition(transition.first.first, transition.first.second, transition.second);
		} else {
			transitionsUnreachableCount.push_back(TransitionCounter(&transition, static_cast < int > (transition.first.second.size())));
			TransitionCounter * counter = &transitionsUnreachableCount.back();
			for (const auto & state : transition.first.second)
				++stateOccurences[state][counter]; // value-initialised to 0 on first access
		}
	}

	// Worklist propagation: every state enters the queue at most once.
	while (!queue.empty()) {
		auto occurencesIt = stateOccurences.find(queue.front());
		queue.pop_front();
		if (occurencesIt == stateOccurences.end())
			continue; // the state is not a child of any transition

		for (const auto & occurence : occurencesIt->second) {
			int & unreachableCount = occurence.first->second;
			unreachableCount -= occurence.second;
			if (unreachableCount == 0) {
				const StateType & to = occurence.first->first->second;
				if (res.addState(to))
					queue.push_back(to);
			}
		}
	}

	for (const auto & state : res.getStates()) {
		if (dfta.getFinalStates().count(state) != 0)
			res.addFinalState(state);
	}

	// A transition is kept exactly when all of its child occurrences became reachable.
	for (const auto & counter : transitionsUnreachableCount) {
		if (counter.second == 0) {
			const Transition & transition = *counter.first;
			res.addTransition(transition.first.first, transition.first.second, transition.second);
		}
	}
	return res;
}

} /* namespace simplify */

} /* namespace automaton */

#endif /* UNREACHABLE_STATES_REMOVER_H_ */