/*
 * PDAToRHPDA.cpp
 *
 *  Created on: 23. 3. 2014
 *	  Author: Jan Travnicek
 */

#include "PDAToRHPDA.h"

#include <automaton/PDA/RealTimeHeightDeterministicDPDA.h>
#include <automaton/PDA/RealTimeHeightDeterministicNPDA.h>
#include <automaton/PDA/DPDA.h>
#include <automaton/PDA/NPDA.h>

#include <alphabet/BottomOfTheStackSymbol.h>

#include <cstddef>
#include <set>
#include <map>
#include <queue>
#include <iterator>
#include <string>
#include <tuple>

#include <label/InitialStateLabel.h>
#include <common/createUnique.hpp>

#include <registration/CastRegistration.hpp>
#include <registration/AlgoRegistration.hpp>

namespace automaton {

/*
 * Identity overload: the argument is already a real-time height-deterministic DPDA,
 * so an unmodified copy of it is returned.
 *
 * @param pda the automaton to "convert"
 * @return a copy of pda
 */
automaton::RealTimeHeightDeterministicDPDA < > PDAToRHPDA::convert ( const automaton::RealTimeHeightDeterministicDPDA < > & pda ) {
	automaton::RealTimeHeightDeterministicDPDA < > copy ( pda );
	return copy;
}

auto PDAToRHPDARealTimeHeightDeterministicDPDA =
	registration::OverloadRegister < PDAToRHPDA, automaton::RealTimeHeightDeterministicDPDA < >, automaton::RealTimeHeightDeterministicDPDA < > > ( PDAToRHPDA::convert );

/*
 * Identity overload: the argument is already a real-time height-deterministic NPDA,
 * so an unmodified copy of it is returned.
 *
 * @param pda the automaton to "convert"
 * @return a copy of pda
 */
automaton::RealTimeHeightDeterministicNPDA < > PDAToRHPDA::convert ( const automaton::RealTimeHeightDeterministicNPDA < > & pda ) {
	automaton::RealTimeHeightDeterministicNPDA < > copy ( pda );
	return copy;
}

auto PDAToRHPDARealTimeHeightDeterministicNPDA =
	registration::OverloadRegister < PDAToRHPDA, automaton::RealTimeHeightDeterministicNPDA < >, automaton::RealTimeHeightDeterministicNPDA < > > ( PDAToRHPDA::convert );

/*
 * Converts a DPDA into a real-time height-deterministic DPDA.
 *
 * Setup:
 *  - A fresh state q0 is derived from InitialStateLabel and made unique against the
 *    original states. It is passed to the result's constructor together with the
 *    bottom-of-the-stack symbol.
 *  - The input alphabet, states and final states are copied from the original.
 *  - The pushdown store alphabet is the original one plus the bottom-of-the-stack symbol.
 *  - An epsilon call transition q0 -> original initial state pushes the original initial
 *    pushdown symbol.
 *
 * Each original transition (from, input, pops) -> (target, pushes) is mapped by shape:
 *  - no pop, no push      -> one local transition;
 *  - one pop, no push     -> one return transition;
 *  - no pop, one push     -> one call transition;
 *  - anything else        -> a chain of return transitions (one per popped symbol, in
 *                            order), then call transitions (one per pushed symbol,
 *                            iterated from the back of `pushes`).
 *
 * Chain steps pass through fresh intermediate states. Their names are derived from
 * "us<n>" and made unique against the states added so far. Only the first step of a
 * chain reads `input`; the remaining steps are epsilon.
 *
 * NOTE(review): the add*Transition calls presumably reject transitions that would break
 * height determinism (e.g. by throwing) -- confirm against RealTimeHeightDeterministicDPDA.
 *
 * @param pda the automaton to convert
 * @return the converted automaton
 */
automaton::RealTimeHeightDeterministicDPDA < > PDAToRHPDA::convert ( const automaton::DPDA < > & pda ) {
	DefaultStateType q0 = common::createUnique ( label::InitialStateLabel::instance < DefaultStateType > ( ), pda.getStates ( ) );

	RealTimeHeightDeterministicDPDA < > res ( q0, alphabet::BottomOfTheStackSymbol::instance < DefaultSymbolType > ( ) );

	res.setInputAlphabet ( pda.getInputAlphabet ( ) );

	for ( const auto & state : pda.getStates ( ) )
		res.addState ( state );

	res.setFinalStates ( pda.getFinalStates ( ) );
	std::set < DefaultSymbolType > pushdownStoreAlphabet = pda.getPushdownStoreAlphabet ( );
	pushdownStoreAlphabet.insert ( alphabet::BottomOfTheStackSymbol::instance < DefaultSymbolType > ( ) );
	res.setPushdownStoreAlphabet ( pushdownStoreAlphabet );

	// Epsilon call from the fresh start state: put the original initial pushdown symbol on the store.
	res.addCallTransition ( q0, pda.getInitialState ( ), pda.getInitialSymbol ( ) );

	const std::string us ( "us" );
	int i = 0; // suffix for intermediate state names; createUnique still enforces uniqueness

	for ( const auto & transition : pda.getTransitions ( ) ) {
		const auto & from = std::get < 0 > ( transition.first );
		const auto & input = std::get < 1 > ( transition.first );
		const auto & pops = std::get < 2 > ( transition.first );
		const auto & target = transition.second.first;
		const auto & pushes = transition.second.second;

		if ( pops.empty ( ) && pushes.empty ( ) ) {
			res.addLocalTransition ( from, input, target );
		} else if ( pops.size ( ) == 1 && pushes.empty ( ) ) {
			res.addReturnTransition ( from, input, pops [ 0 ], target );
		} else if ( pops.empty ( ) && pushes.size ( ) == 1 ) {
			res.addCallTransition ( from, input, target, pushes [ 0 ] );
		} else {
			// Here the chain has at least two elementary steps.
			const std::size_t steps = pops.size ( ) + pushes.size ( );
			std::size_t step = 0;
			DefaultStateType lastUS = common::createUnique ( DefaultStateType ( us + ext::to_string ( i ) ), res.getStates ( ) );

			// Emits one elementary step of the chain.
			// Source: the first step leaves `from`; later steps leave the previous
			// intermediate state. Target: the last step enters `target`; earlier steps
			// enter a freshly named intermediate state.
			auto addStep = [&] ( bool isPop, const DefaultSymbolType & symbol ) {
				const bool first = step == 0;
				const bool last = step + 1 == steps;

				DefaultStateType fromState = first ? from : lastUS;
				if ( ! first )
					lastUS = common::createUnique ( DefaultStateType ( us + ext::to_string ( ++i ) ), res.getStates ( ) );
				DefaultStateType toState = last ? target : lastUS;

				res.addState ( fromState );
				res.addState ( toState );

				if ( isPop ) {
					if ( first )
						res.addReturnTransition ( fromState, input, symbol, toState );
					else
						res.addReturnTransition ( fromState, symbol, toState );
				} else {
					if ( first )
						res.addCallTransition ( fromState, input, toState, symbol );
					else
						res.addCallTransition ( fromState, toState, symbol );
				}

				++step;
			};

			for ( const DefaultSymbolType & pop : pops )
				addStep ( true, pop );

			for ( const DefaultSymbolType & push : ext::make_reverse ( pushes ) )
				addStep ( false, push );
		}
	}

	return res;
}

auto PDAToRHPDADPDA = registration::OverloadRegister < PDAToRHPDA, automaton::RealTimeHeightDeterministicDPDA < >, automaton::DPDA < > > ( PDAToRHPDA::convert );

/*
 * Converts an NPDA into a real-time height-deterministic NPDA.
 *
 * Setup:
 *  - The input alphabet, states and final states are copied from the original.
 *  - The pushdown store alphabet is the original one plus the bottom-of-the-stack symbol.
 *  - A fresh state q0 is derived from InitialStateLabel, made unique against the copied
 *    states, and added as an initial state.
 *  - An epsilon call transition q0 -> original initial state pushes the original initial
 *    pushdown symbol.
 *
 * Each alternative (from, input, pops) -> (target, pushes) of every original transition
 * is mapped by shape:
 *  - no pop, no push      -> one local transition;
 *  - one pop, no push     -> one return transition;
 *  - no pop, one push     -> one call transition;
 *  - anything else        -> a chain of return transitions (one per popped symbol, in
 *                            order), then call transitions (one per pushed symbol,
 *                            iterated from the back of `pushes`).
 *
 * Chain steps pass through fresh intermediate states. Their names are derived from
 * "us<n>" and made unique against the states added so far. Only the first step of a
 * chain reads `input`; the remaining steps are epsilon.
 *
 * @param pda the automaton to convert
 * @return the converted automaton
 */
automaton::RealTimeHeightDeterministicNPDA < > PDAToRHPDA::convert ( const automaton::NPDA < > & pda ) {
	RealTimeHeightDeterministicNPDA < > res ( alphabet::BottomOfTheStackSymbol::instance < DefaultSymbolType > ( ) );

	res.setInputAlphabet ( pda.getInputAlphabet ( ) );
	res.setStates ( pda.getStates ( ) );
	res.setFinalStates ( pda.getFinalStates ( ) );
	std::set < DefaultSymbolType > pushdownStoreAlphabet = pda.getPushdownStoreAlphabet ( );
	pushdownStoreAlphabet.insert ( alphabet::BottomOfTheStackSymbol::instance < DefaultSymbolType > ( ) );
	res.setPushdownStoreAlphabet ( pushdownStoreAlphabet );

	DefaultStateType q0 = common::createUnique ( label::InitialStateLabel::instance < DefaultStateType > ( ), res.getStates ( ) );
	res.addState ( q0 );
	res.addInitialState ( q0 );

	// Epsilon call from the fresh start state: put the original initial pushdown symbol on the store.
	res.addCallTransition ( q0, pda.getInitialState ( ), pda.getInitialSymbol ( ) );

	const std::string us ( "us" );
	int i = 0; // suffix for intermediate state names; createUnique still enforces uniqueness

	for ( const auto & transition : pda.getTransitions ( ) ) {
		const auto & from = std::get < 0 > ( transition.first );
		const auto & input = std::get < 1 > ( transition.first );
		const auto & pops = std::get < 2 > ( transition.first );

		for ( const auto & to : transition.second ) {
			const auto & target = to.first;
			const auto & pushes = to.second;

			if ( pops.empty ( ) && pushes.empty ( ) ) {
				res.addLocalTransition ( from, input, target );
			} else if ( pops.size ( ) == 1 && pushes.empty ( ) ) {
				res.addReturnTransition ( from, input, pops [ 0 ], target );
			} else if ( pops.empty ( ) && pushes.size ( ) == 1 ) {
				res.addCallTransition ( from, input, target, pushes [ 0 ] );
			} else {
				// Here the chain has at least two elementary steps.
				const std::size_t steps = pops.size ( ) + pushes.size ( );
				std::size_t step = 0;
				DefaultStateType lastUS = common::createUnique ( DefaultStateType ( us + ext::to_string ( i ) ), res.getStates ( ) );

				// Emits one elementary step of the chain.
				// Source: the first step leaves `from`; later steps leave the previous
				// intermediate state. Target: the last step enters `target`; earlier steps
				// enter a freshly named intermediate state.
				auto addStep = [&] ( bool isPop, const DefaultSymbolType & symbol ) {
					const bool first = step == 0;
					const bool last = step + 1 == steps;

					DefaultStateType fromState = first ? from : lastUS;
					if ( ! first )
						lastUS = common::createUnique ( DefaultStateType ( us + ext::to_string ( ++i ) ), res.getStates ( ) );
					DefaultStateType toState = last ? target : lastUS;

					res.addState ( fromState );
					res.addState ( toState );

					if ( isPop ) {
						if ( first )
							res.addReturnTransition ( fromState, input, symbol, toState );
						else
							res.addReturnTransition ( fromState, symbol, toState );
					} else {
						if ( first )
							res.addCallTransition ( fromState, input, toState, symbol );
						else
							res.addCallTransition ( fromState, toState, symbol );
					}

					++step;
				};

				// Range-for instead of std::for_each: no dependency on <algorithm>, matching the DPDA overload.
				for ( const DefaultSymbolType & pop : pops )
					addStep ( true, pop );

				for ( const DefaultSymbolType & push : ext::make_reverse ( pushes ) )
					addStep ( false, push );
			}
		}
	}

	return res;
}

auto PDAToRHPDANPDA = registration::OverloadRegister < PDAToRHPDA, automaton::RealTimeHeightDeterministicNPDA < >, automaton::NPDA < > > ( PDAToRHPDA::convert );

/*
 * Type-erased entry point: hands the wrapped automaton's data to `dispatch`.
 * `dispatch` presumably selects among the OverloadRegister registrations in this file
 * by the data's concrete type -- confirm in the dispatcher base.
 *
 * @param automaton the wrapped automaton to convert
 * @return the result of the selected overload, wrapped as an Automaton
 */
automaton::Automaton PDAToRHPDA::convert ( const Automaton & automaton ) {
	const auto & data = automaton.getData ( );
	return dispatch ( data );
}

} /* namespace automaton */

namespace alib {

// Register PDAToRHPDA::convert with the cast registry:
// DPDA -> RealTimeHeightDeterministicDPDA, and NPDA -> RealTimeHeightDeterministicNPDA.
// The template argument order is <target, source>, as the variable names spell out.
auto RealTimeHeightDeterministicDPDAFromDPDA =
	registration::CastRegister < automaton::RealTimeHeightDeterministicDPDA < >, automaton::DPDA < > > ( automaton::PDAToRHPDA::convert );

auto RealTimeHeightDeterministicNPDAFromNPDA =
	registration::CastRegister < automaton::RealTimeHeightDeterministicNPDA < >, automaton::NPDA < > > ( automaton::PDAToRHPDA::convert );

} /* namespace alib */