/*
 * ExactPatternMatchingAutomaton.cpp
 *
 *  Created on: 9. 2. 2014
 *      Author: Jan Travnicek
 */

#include "ExactPatternMatchingAutomaton.h"
#include "ExactSubtreeMatchingAutomaton.h"
#include <tree/properties/SubtreeJumpTable.h>

#include <tree/ranked/RankedTree.h>
#include <tree/ranked/RankedPattern.h>
#include <tree/ranked/PrefixRankedTree.h>
#include <tree/ranked/PrefixRankedPattern.h>
#include <tree/ranked/PrefixRankedBarTree.h>
#include <tree/ranked/PrefixRankedBarPattern.h>

#include <automaton/PDA/InputDrivenNPDA.h>
#include <automaton/PDA/VisiblyPushdownNPDA.h>
#include <automaton/PDA/NPDA.h>
#include <automaton/TA/NFTA.h>

#include <alphabet/BottomOfTheStackSymbol.h>

#include <deque>
#include <alphabet/RankedSymbol.h>
#include <registration/AlgoRegistration.hpp>

namespace arbology {

namespace exact {

/*
 * Entry point for the type-erased tree wrapper: unwraps the concrete pattern
 * and hands it to dispatch, which presumably selects one of the overloads
 * registered through registration::OverloadRegister in this file.
 */
automaton::Automaton ExactPatternMatchingAutomaton::construct ( const tree::Tree & pattern ) {
	const auto & concretePattern = pattern.getData ( );

	return dispatch ( concretePattern );
}

/*
 * A PrefixRankedTree contains no subtree wildcard, so pattern matching
 * coincides with subtree matching; the construction is delegated to
 * ExactSubtreeMatchingAutomaton.
 */
automaton::InputDrivenNPDA < > ExactPatternMatchingAutomaton::construct ( const tree::PrefixRankedTree < > & pattern ) {
	automaton::InputDrivenNPDA < > subtreeMatcher = ExactSubtreeMatchingAutomaton::construct ( pattern );

	return subtreeMatcher;
}

// Makes the overload above available through registration::OverloadRegister.
auto ExactPatternMatchingAutomatonPrefixRankedTree = registration::OverloadRegister < ExactPatternMatchingAutomaton, automaton::InputDrivenNPDA < >, tree::PrefixRankedTree < > > ( ExactPatternMatchingAutomaton::construct );

/*
 * Returns the pushdown symbols to push when the pattern symbol at position i
 * of pattern.getContent ( ) is read. The result has one entry per child of
 * that symbol, in child order: 'R' for a child that is the subtree wildcard,
 * 'T' for any other child.
 *
 * patternSubtreeJumpTable [ k ] is used as the position just past the subtree
 * starting at position k, i.e. where the next sibling begins.
 */
std::vector < DefaultSymbolType > computeRHS ( const tree::PrefixRankedPattern < > & pattern, const std::vector < int > & patternSubtreeJumpTable, int i ) {
	const std::vector < std::ranked_symbol < > > & content = pattern.getContent ( );
	const std::ranked_symbol < > & subtreeWildcard = pattern.getSubtreeWildcard ( );

	std::vector < DefaultSymbolType > markers;

	// In prefix notation the first child immediately follows its parent.
	int child = i + 1;

	for ( unsigned remaining = static_cast < unsigned > ( content [ i ].getRank ( ) ); remaining > 0; --remaining ) {
		const bool isWildcard = content [ child ] == subtreeWildcard;

		markers.push_back ( DefaultSymbolType ( isWildcard ? 'R' : 'T' ) );

		// A wildcard occupies a single position; any other child is skipped as a whole subtree.
		child = isWildcard ? child + 1 : patternSubtreeJumpTable [ child ];
	}

	return markers;
}

/*
 * Builds a nondeterministic PDA for the pattern given in prefix ranked notation.
 *
 * States 0 .. n are created, where n is the length of pattern.getContent ( ).
 * The transitions that read the k-th pattern symbol lead from state k - 1 to
 * state k, and state n is the only final state. State 0 additionally loops
 * over every non-wildcard symbol of the alphabet.
 *
 * Pushdown symbols:
 *   'T' - a subtree still to be read (also the initial pushdown symbol);
 *   'R' - pushed (via computeRHS) for a child that is the subtree wildcard. While
 *         in a wildcard's state the automaton stays put until a leaf pops 'R',
 *         which moves it on to the next state.
 */
automaton::NPDA < > ExactPatternMatchingAutomaton::construct ( const tree::PrefixRankedPattern < > & pattern ) {
	const DefaultSymbolType subtreeMarker ( 'T' );
	const DefaultSymbolType wildcardMarker ( 'R' );
	const std::vector < DefaultSymbolType > popSubtree ( 1, subtreeMarker );
	const std::vector < DefaultSymbolType > popWildcard ( 1, wildcardMarker );
	const std::vector < DefaultSymbolType > pushNothing;

	const std::ranked_symbol < > & subtreeWildcard = pattern.getSubtreeWildcard ( );

	automaton::NPDA < > res ( DefaultStateType ( 0 ), subtreeMarker );

	// Declared up front, before any transition that pushes or pops 'R' is added.
	res.setPushdownStoreAlphabet ( { subtreeMarker, wildcardMarker } );

	// Single pass over the alphabet: register each input symbol together with its self-loop in state 0.
	for ( const std::ranked_symbol < > & symbol : pattern.getAlphabet ( ) ) {
		if ( symbol == subtreeWildcard )
			continue;

		const DefaultSymbolType input ( alphabet::RankedSymbol < > { symbol } );

		res.addInputSymbol ( input );
		res.addTransition ( DefaultStateType ( 0 ), input, popSubtree, DefaultStateType ( 0 ), std::vector < DefaultSymbolType > ( static_cast < size_t > ( symbol.getRank ( ) ), subtreeMarker ) );
	}

	const std::vector < int > patternSubtreeJumpTable = tree::properties::SubtreeJumpTable::compute ( pattern );

	int i = 1;

	for ( const std::ranked_symbol < > & symbol : pattern.getContent ( ) ) {
		const DefaultStateType source ( i - 1 );
		const DefaultStateType target ( i );

		res.addState ( target );

		// Explicit braces: the wildcard branch is itself a loop, and an unbraced if/for/else is easy to break.
		if ( symbol == subtreeWildcard ) {
			for ( const std::ranked_symbol < > & alphabetSymbol : pattern.getAlphabet ( ) ) {
				if ( alphabetSymbol == subtreeWildcard )
					continue;

				const DefaultSymbolType input ( alphabet::RankedSymbol < > { alphabetSymbol } );
				const unsigned rank = static_cast < unsigned > ( alphabetSymbol.getRank ( ) );

				if ( rank == 0 ) {
					// A leaf popping 'T' stays in the wildcard state ...
					res.addTransition ( source, input, popSubtree, source, pushNothing );
					// ... while a leaf popping 'R' finishes the skipped subtree and advances.
					res.addTransition ( source, input, popWildcard, target, pushNothing );
				} else {
					std::vector < DefaultSymbolType > push ( rank, subtreeMarker );
					res.addTransition ( source, input, popSubtree, source, push );

					// When 'R' is popped, the last pushed marker becomes 'R', passing the wildcard marker further down.
					push.back ( ) = wildcardMarker;
					res.addTransition ( source, input, popWildcard, source, push );
				}
			}
		} else {
			res.addTransition ( source, DefaultSymbolType ( alphabet::RankedSymbol < > { symbol } ), popSubtree, target, computeRHS ( pattern, patternSubtreeJumpTable, i - 1 ) );
		}

		i++;
	}

	res.addFinalState ( DefaultStateType ( i - 1 ) );
	return res;
}

// Makes the overload above available through registration::OverloadRegister.
auto ExactPatternMatchingAutomatonPrefixRankedPattern = registration::OverloadRegister < ExactPatternMatchingAutomaton, automaton::NPDA < >, tree::PrefixRankedPattern < > > ( ExactPatternMatchingAutomaton::construct );

/*
 * A PrefixRankedBarTree contains no subtree wildcard, so pattern matching
 * coincides with subtree matching; the construction is delegated to
 * ExactSubtreeMatchingAutomaton.
 */
automaton::InputDrivenNPDA < > ExactPatternMatchingAutomaton::construct ( const tree::PrefixRankedBarTree < > & pattern ) {
	automaton::InputDrivenNPDA < > subtreeMatcher = ExactSubtreeMatchingAutomaton::construct ( pattern );

	return subtreeMatcher;
}

// Makes the overload above available through registration::OverloadRegister.
auto ExactPatternMatchingAutomatonPrefixRankedBarTree = registration::OverloadRegister < ExactPatternMatchingAutomaton, automaton::InputDrivenNPDA < >, tree::PrefixRankedBarTree < > > ( ExactPatternMatchingAutomaton::construct );

/*
 * Builds a visibly pushdown NPDA for the pattern given in prefix ranked bar notation.
 *
 * Symbols listed in pattern.getBars ( ) become return symbols (they pop a marker);
 * all other symbols, except the subtree wildcard and the variables bar, become
 * call symbols (they push a marker). States 0 .. n follow pattern.getContent ( ),
 * state n being the only final state, and state 0 loops over every input symbol.
 *
 * For a subtree wildcard, any call pushes 'R' and advances. In the state of the
 * following variables bar, calls and returns loop on 'T', and the return that
 * pops 'R' advances.
 */
automaton::VisiblyPushdownNPDA < > ExactPatternMatchingAutomaton::construct ( const tree::PrefixRankedBarPattern < > & pattern ) {
	const DefaultSymbolType bottomOfStack = alphabet::BottomOfTheStackSymbol::instance < DefaultSymbolType > ( );
	const DefaultSymbolType subtreeMarker ( 'T' );
	const DefaultSymbolType wildcardMarker ( 'R' );

	const std::ranked_symbol < > & subtreeWildcard = pattern.getSubtreeWildcard ( );
	const std::ranked_symbol < > & variablesBar = pattern.getVariablesBar ( );
	const auto & bars = pattern.getBars ( );

	automaton::VisiblyPushdownNPDA < > res ( bottomOfStack );

	res.addState ( DefaultStateType ( 0 ) );
	res.addInitialState ( DefaultStateType ( 0 ) );

	res.setPushdownStoreAlphabet ( { bottomOfStack, subtreeMarker, wildcardMarker } );

	// Input alphabet together with the self-loops of state 0.
	for ( const std::ranked_symbol < > & symbol : pattern.getAlphabet ( ) ) {
		if ( symbol == subtreeWildcard || symbol == variablesBar )
			continue;

		const DefaultSymbolType input ( alphabet::RankedSymbol < > { symbol } );

		if ( bars.count ( symbol ) ) {
			res.addReturnInputSymbol ( input );
			res.addReturnTransition ( DefaultStateType ( 0 ), input, subtreeMarker, DefaultStateType ( 0 ) );
		} else {
			res.addCallInputSymbol ( input );
			res.addCallTransition ( DefaultStateType ( 0 ), input, DefaultStateType ( 0 ), subtreeMarker );
		}
	}

	int position = 0;

	for ( const std::ranked_symbol < > & symbol : pattern.getContent ( ) ) {
		const DefaultStateType source ( position );
		const DefaultStateType target ( position + 1 );

		res.addState ( target );

		if ( symbol == subtreeWildcard ) {
			// Any call symbol may start the substituted subtree; it pushes 'R' and advances.
			for ( const std::ranked_symbol < > & other : pattern.getAlphabet ( ) ) {
				if ( other == subtreeWildcard || other == variablesBar || bars.count ( other ) )
					continue;

				res.addCallTransition ( source, DefaultSymbolType ( alphabet::RankedSymbol < > { other } ), target, wildcardMarker );
			}
		} else if ( symbol == variablesBar ) {
			// Calls and returns on 'T' loop; only the return popping 'R' leaves the substituted subtree.
			for ( const std::ranked_symbol < > & other : pattern.getAlphabet ( ) ) {
				if ( other == subtreeWildcard || other == variablesBar )
					continue;

				const DefaultSymbolType input ( alphabet::RankedSymbol < > { other } );

				if ( bars.count ( other ) ) {
					res.addReturnTransition ( source, input, subtreeMarker, source );
					res.addReturnTransition ( source, input, wildcardMarker, target );
				} else {
					res.addCallTransition ( source, input, source, subtreeMarker );
				}
			}
		} else if ( bars.count ( symbol ) ) {
			res.addReturnTransition ( source, DefaultSymbolType ( alphabet::RankedSymbol < > { symbol } ), subtreeMarker, target );
		} else {
			res.addCallTransition ( source, DefaultSymbolType ( alphabet::RankedSymbol < > { symbol } ), target, subtreeMarker );
		}

		++position;
	}

	res.addFinalState ( DefaultStateType ( position ) );
	return res;
}

// Makes the overload above available through registration::OverloadRegister.
auto ExactPatternMatchingAutomatonPrefixRankedBarPattern = registration::OverloadRegister < ExactPatternMatchingAutomaton, automaton::VisiblyPushdownNPDA < >, tree::PrefixRankedBarPattern < > > ( ExactPatternMatchingAutomaton::construct );

/*
 * A RankedTree contains no subtree wildcard, so pattern matching coincides
 * with subtree matching; the construction is delegated to
 * ExactSubtreeMatchingAutomaton.
 */
automaton::NFTA < > ExactPatternMatchingAutomaton::construct ( const tree::RankedTree < > & pattern ) {
	automaton::NFTA < > subtreeMatcher = ExactSubtreeMatchingAutomaton::construct ( pattern );

	return subtreeMatcher;
}

// Makes the overload above available through registration::OverloadRegister.
auto ExactPatternMatchingAutomatonRankedTree = registration::OverloadRegister < ExactPatternMatchingAutomaton, automaton::NFTA < >, tree::RankedTree < > > ( ExactPatternMatchingAutomaton::construct );

/*
 * Adds to res the states and transitions recognising the pattern subtree rooted
 * at node and returns the state reached for that subtree.
 *
 * - A subtree wildcard gets a fresh state, reachable by every input symbol whose
 *   children are all in loopState.
 * - Any other node first processes its children (so their states are numbered
 *   earlier), then gets a fresh state reached by its own symbol from the
 *   children's states.
 *
 * nextState is the counter used to number new states; it is advanced by one
 * per state created.
 */
DefaultStateType constructRecursivePattern ( const std::tree < std::ranked_symbol < > > & node, automaton::NFTA < > & res, const std::ranked_symbol < > & subtreeWildcard, const DefaultStateType & loopState, int & nextState ) {
	if ( node.getData ( ) == subtreeWildcard ) {
		DefaultStateType wildcardState ( nextState++ );
		res.addState ( wildcardState );

		for ( const std::ranked_symbol < > & symbol : res.getInputAlphabet ( ) ) {
			std::vector < DefaultStateType > children ( static_cast < size_t > ( symbol.getRank ( ) ), loopState );
			res.addTransition ( symbol, children, wildcardState );
		}

		return wildcardState;
	}

	std::vector < DefaultStateType > childStates;
	childStates.reserve ( static_cast < size_t > ( node.getData ( ).getRank ( ) ) );

	// Left-to-right recursion; the order matters because it determines state numbering.
	for ( const std::tree < std::ranked_symbol < > > & child : node.getChildren ( ) )
		childStates.push_back ( constructRecursivePattern ( child, res, subtreeWildcard, loopState, nextState ) );

	DefaultStateType nodeState ( nextState++ );
	res.addState ( nodeState );
	res.addTransition ( node.getData ( ), childStates, nodeState );

	return nodeState;
}

/*
 * Builds a nondeterministic finite tree automaton for a ranked pattern.
 *
 * The input alphabet is the pattern's alphabet without the subtree wildcard.
 * State 0 (loopState) is reachable by every input symbol from children that are
 * all in state 0. The pattern itself is encoded by constructRecursivePattern,
 * and the state it returns for the pattern root is the only final state.
 */
automaton::NFTA < > ExactPatternMatchingAutomaton::construct ( const tree::RankedPattern < > & pattern ) {
	std::set < std::ranked_symbol < > > inputAlphabet ( pattern.getAlphabet ( ) );
	inputAlphabet.erase ( pattern.getSubtreeWildcard ( ) );

	automaton::NFTA < > res;
	res.setInputAlphabet ( inputAlphabet );

	int nextState = 0;

	const DefaultStateType loopState ( nextState++ );
	res.addState ( loopState );

	for ( const std::ranked_symbol < > & symbol : res.getInputAlphabet ( ) ) {
		std::vector < DefaultStateType > children ( static_cast < size_t > ( symbol.getRank ( ) ), loopState );
		res.addTransition ( symbol, children, loopState );
	}

	const DefaultStateType rootState = constructRecursivePattern ( pattern.getContent ( ), res, pattern.getSubtreeWildcard ( ), loopState, nextState );
	res.addFinalState ( rootState );

	return res;
}

// Makes the overload above available through registration::OverloadRegister.
auto ExactPatternMatchingAutomatonRankedPattern = registration::OverloadRegister < ExactPatternMatchingAutomaton, automaton::NFTA < >, tree::RankedPattern < > > ( ExactPatternMatchingAutomaton::construct );

} /* namespace exact */

} /* namespace arbology */