From 4bc296c08a54a395fbfd0f4738e6cbaabf32c1f7 Mon Sep 17 00:00:00 2001
From: Jan Travnicek <Jan.Travnicek@fit.cvut.cz>
Date: Tue, 15 Jan 2019 13:44:32 +0100
Subject: [PATCH] Speed up generating strings up to a given length from a grammar

---
 .../grammar/generate/GenerateUpToLength.cpp   | 10 +--
 .../src/grammar/generate/GenerateUpToLength.h | 76 ++++++++++++++-----
 .../generate/GrammarGenerateUpToLength.cpp    | 74 ++++++++++++++++--
 alib2common/src/PrimitiveRegistrator.cpp      |  3 +
 alib2data/src/PrimitiveRegistrator.cpp        |  3 +
 .../src/extensions/container/forward_tree.hpp |  2 +-
 alib2std/src/extensions/container/tree.hpp    |  2 +-
 alib2std/src/extensions/container/trie.hpp    |  2 +-
 tests.anormalize.sh                           |  2 +-
 9 files changed, 139 insertions(+), 35 deletions(-)

diff --git a/alib2algo/src/grammar/generate/GenerateUpToLength.cpp b/alib2algo/src/grammar/generate/GenerateUpToLength.cpp
index a481ff3eba..9a0e6b4ad0 100644
--- a/alib2algo/src/grammar/generate/GenerateUpToLength.cpp
+++ b/alib2algo/src/grammar/generate/GenerateUpToLength.cpp
@@ -12,11 +12,11 @@ namespace grammar {
 
 namespace generate {
 
-auto GenerateUpToLengthEpsilonFreeCFG = registration::AbstractRegister < GenerateUpToLength, ext::set < string::LinearString < > >, const grammar::EpsilonFreeCFG < > &, unsigned > ( GenerateUpToLength::generate );
-auto GenerateUpToLengthGNF = registration::AbstractRegister < GenerateUpToLength, ext::set < string::LinearString < > >, const grammar::GNF < > &, unsigned > ( GenerateUpToLength::generate );
-auto GenerateUpToLengthCNF = registration::AbstractRegister < GenerateUpToLength, ext::set < string::LinearString < > >, const grammar::CNF < > &, unsigned > ( GenerateUpToLength::generate );
-auto GenerateUpToLengthLeftRG = registration::AbstractRegister < GenerateUpToLength, ext::set < string::LinearString < > >, const grammar::LeftRG < > &, unsigned > ( GenerateUpToLength::generate );
-auto GenerateUpToLengthRightRG = registration::AbstractRegister < GenerateUpToLength, ext::set < string::LinearString < > >, const grammar::RightRG < > &, unsigned > ( GenerateUpToLength::generate );
+auto GenerateUpToLengthEpsilonFreeCFG = registration::AbstractRegister < GenerateUpToLength, ext::trie < DefaultSymbolType, bool >, const grammar::EpsilonFreeCFG < > &, unsigned > ( GenerateUpToLength::generate );
+auto GenerateUpToLengthGNF = registration::AbstractRegister < GenerateUpToLength, ext::trie < DefaultSymbolType, bool >, const grammar::GNF < > &, unsigned > ( GenerateUpToLength::generate );
+auto GenerateUpToLengthCNF = registration::AbstractRegister < GenerateUpToLength, ext::trie < DefaultSymbolType, bool >, const grammar::CNF < > &, unsigned > ( GenerateUpToLength::generate );
+auto GenerateUpToLengthLeftRG = registration::AbstractRegister < GenerateUpToLength, ext::trie < DefaultSymbolType, bool >, const grammar::LeftRG < > &, unsigned > ( GenerateUpToLength::generate );
+auto GenerateUpToLengthRightRG = registration::AbstractRegister < GenerateUpToLength, ext::trie < DefaultSymbolType, bool >, const grammar::RightRG < > &, unsigned > ( GenerateUpToLength::generate );
 
 } /* namespace generate */
 
diff --git a/alib2algo/src/grammar/generate/GenerateUpToLength.h b/alib2algo/src/grammar/generate/GenerateUpToLength.h
index 5007490124..d1efd89586 100644
--- a/alib2algo/src/grammar/generate/GenerateUpToLength.h
+++ b/alib2algo/src/grammar/generate/GenerateUpToLength.h
@@ -12,6 +12,8 @@
 #include <string/LinearString.h>
 #include <alib/set>
 #include <alib/deque>
+#include <alib/trie>
+#include <alib/iterator>
 
 #include <grammar/ContextFree/EpsilonFreeCFG.h>
 #include <grammar/ContextFree/GNF.h>
@@ -29,30 +31,55 @@ namespace generate {
  * Implements algorithms from Melichar, chapter 3.3
  */
 class GenerateUpToLength {
+	template < class TerminalSymbolType > // helper: recursively removes subtries that hold no node with a true data flag
+	static bool pruneTrie ( ext::trie < TerminalSymbolType, bool > & trie ) { // returns true when this node can itself be erased by its parent
+		for ( typename std::map < TerminalSymbolType, ext::trie < TerminalSymbolType, bool > >::iterator iter = trie.getChildren ( ).begin ( ); iter != trie.getChildren ( ).end ( ); ) // no increment clause: the body advances iter, via erase or ++
+			if ( pruneTrie ( iter->second ) ) // prune the child first, then decide whether the child is removable
+				iter = trie.erase ( iter ); // continue from the iterator returned by erase
+			else
+				++ iter;
+
+		return trie.getData ( ) == false && trie.getChildren ( ).size ( ) == 0; // removable: data flag is false and no children are left after pruning
+	}
+
 public:
 	template < class T, class TerminalSymbolType = typename grammar::TerminalSymbolTypeOfGrammar < T >, class NontermimnalSymbolType = typename grammar::NonterminalSymbolTypeOfGrammar < T > >
-	static ext::set < string::LinearString < TerminalSymbolType > > generate ( const T & grammar, unsigned length );
+	static ext::trie < TerminalSymbolType, bool > generate ( const T & grammar, unsigned length );
 };
 
 template < class T, class TerminalSymbolType, class NonterminalSymbolType >
-ext::set < string::LinearString < TerminalSymbolType > > GenerateUpToLength::generate ( const T & grammar, unsigned length ) {
-	ext::set < string::LinearString < TerminalSymbolType > > res;
+ext::trie < TerminalSymbolType, bool > GenerateUpToLength::generate ( const T & grammar, unsigned length ) {
+	ext::trie < TerminalSymbolType, bool > res ( false );
 
 	ext::map < NonterminalSymbolType, ext::set < ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > > > rules = grammar::RawRules::getRawRules ( grammar );
 	if ( grammar.getGeneratesEpsilon ( ) ) {
-		res.insert ( string::LinearString < TerminalSymbolType > { } );
+		res.getData ( ) = true; // TODO improve interface
 		rules [ grammar.getInitialSymbol ( ) ].erase ( ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > { } );
 	}
 
-	ext::deque < std::pair < ext::vector < TerminalSymbolType >, ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > > > data;
-	data.push_back ( std::make_pair ( ext::vector < TerminalSymbolType > { }, ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > { ext::variant < TerminalSymbolType, NonterminalSymbolType > ( grammar.getInitialSymbol ( ) ) } ) );
+	struct Node {
+		ext::variant < TerminalSymbolType, NonterminalSymbolType > m_symbol;
+		std::shared_ptr < Node > m_parent;
+		unsigned m_depth;
+
+		Node ( ext::variant < TerminalSymbolType, NonterminalSymbolType > symbol, std::shared_ptr < Node > parent, unsigned depth ) : m_symbol ( std::move ( symbol ) ), m_parent ( std::move ( parent ) ), m_depth ( depth ) {
+		}
+	};
+
+	ext::deque < std::tuple < typename ext::trie < TerminalSymbolType, bool > *, unsigned, std::shared_ptr < Node > > > data;
+	data.push_back ( std::make_tuple ( & res, 0, std::make_shared < Node > ( grammar.getInitialSymbol ( ), nullptr, 1 ) ) );
 
 	while ( ! data.empty ( ) ) {
-		std::pair < ext::vector < TerminalSymbolType >, ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > > item = std::move ( data.back ( ) );
+		unsigned generatedLen = std::get < 1 > ( data.back ( ) );
+		typename ext::trie < TerminalSymbolType, bool > * trieNode = std::get < 0 > ( data.back ( ) );
+
+		std::shared_ptr < Node > stackTop = std::get < 2 > ( data.back ( ) );
+		ext::variant < TerminalSymbolType, NonterminalSymbolType > top = stackTop->m_symbol;
+		unsigned stackSize = stackTop->m_depth - 1;
+		stackTop = stackTop->m_parent;
+
 		data.pop_back ( );
 
-		ext::variant < TerminalSymbolType, NonterminalSymbolType > top = std::move ( item.second.back ( ) );
-		item.second.pop_back ( );
 		if ( ! top.template is < NonterminalSymbolType > ( ) ) // maybe not needed
 			continue;
 		auto rule = rules.find ( top.template get < NonterminalSymbolType > ( ) );
@@ -60,25 +87,32 @@ ext::set < string::LinearString < TerminalSymbolType > > GenerateUpToLength::gen
 			continue;
 
 		for ( const ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > & rhs : rule->second ) {
-			if ( item.first.size ( ) + item.second.size ( ) + rhs.size ( ) > length ) continue;
-
-			ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > newStack ( item.second );
-			ext::vector < TerminalSymbolType > newString ( item.first );
-			newStack.insert ( newStack.end ( ), rhs.rbegin ( ), rhs.rend ( ) );
-
-			while ( ! newStack.empty ( ) && grammar.getTerminalAlphabet ( ).count ( newStack.back ( ) ) ) {
-				newString.push_back ( std::move ( newStack.back ( ).template get < TerminalSymbolType > ( ) ) );
-				newStack.pop_back ( );
+			if ( generatedLen + stackSize + rhs.size ( ) > length )
+				continue;
+
+			std::shared_ptr < Node > newStackTop = stackTop;
+			for ( const ext::variant < TerminalSymbolType, NonterminalSymbolType > & symbol : ext::make_reverse ( rhs ) )
+				newStackTop = std::make_shared < Node > ( symbol, newStackTop, newStackTop ? newStackTop->m_depth + 1 : 1 );
+
+			typename ext::trie < TerminalSymbolType, bool > * newTrieNode = trieNode;
+			unsigned newGeneratedLen = generatedLen;
+			while ( newStackTop != nullptr && grammar.getTerminalAlphabet ( ).count ( newStackTop->m_symbol ) ) {
+				newTrieNode->insert ( newStackTop->m_symbol.template get < TerminalSymbolType > ( ), ext::trie < TerminalSymbolType, bool > ( false ) );
+				newTrieNode = & newTrieNode->getChildren ( ).at ( newStackTop->m_symbol.template get < TerminalSymbolType > ( ) );
+				newStackTop = newStackTop->m_parent;
+				newGeneratedLen += 1;
 			}
 
-			if ( newStack.empty ( ) ) {
-				res.insert ( string::LinearString < TerminalSymbolType > ( newString ) );
+			if ( newStackTop == nullptr ) {
+				newTrieNode->getData ( ) = true;
 			} else {
-				data.push_back ( std::make_pair ( std::move ( newString ), std::move ( newStack ) ) );
+				data.push_back ( std::make_tuple ( newTrieNode, newGeneratedLen, newStackTop ) );
 			}
 		}
 	}
 
+	pruneTrie ( res );
+
 	return res;
 }
 
diff --git a/alib2algo/test-src/grammar/generate/GrammarGenerateUpToLength.cpp b/alib2algo/test-src/grammar/generate/GrammarGenerateUpToLength.cpp
index 023d3c1a74..76ac22a25b 100644
--- a/alib2algo/test-src/grammar/generate/GrammarGenerateUpToLength.cpp
+++ b/alib2algo/test-src/grammar/generate/GrammarGenerateUpToLength.cpp
@@ -17,6 +17,15 @@ void GrammarGenerateUpToLength::setUp() {
 void GrammarGenerateUpToLength::tearDown() {
 }
 
+unsigned countStrings ( const ext::trie < DefaultSymbolType, bool > & node ) { // counts the nodes in this subtrie whose data flag is true
+	unsigned res = node.getData ( ); // bool converts to 0 or 1, so this counts the node itself
+
+	for ( const auto & child : node.getChildren ( ) )
+		res += countStrings ( child.second ); // child.second is the child subtrie; add its count
+
+	return res;
+}
+
 void GrammarGenerateUpToLength::testGenerate1() {
 	DefaultSymbolType S = DefaultSymbolType("S");
 	DefaultSymbolType A = DefaultSymbolType("A");
@@ -34,7 +43,28 @@ void GrammarGenerateUpToLength::testGenerate1() {
 
 	ext::set<string::LinearString < >> strings;
 
-	CPPUNIT_ASSERT(strings == grammar::generate::GenerateUpToLength::generate(grammar1, 5));
+	ext::trie < DefaultSymbolType, bool > generated = grammar::generate::GenerateUpToLength::generate(grammar1, 5);
+
+	generated.nicePrint ( std::cout );
+	std::cout << std::endl;
+
+	for(const string::LinearString < > & str : strings) {
+		bool flag = true;
+		ext::trie < DefaultSymbolType, bool > * node = & generated;
+		for ( const DefaultSymbolType & symbol : str.getContent ( ) ) {
+			auto iter = node->getChildren ( ).find ( symbol );
+			if ( iter == node->getChildren ( ).end ( ) ) {
+				flag = false;
+				break;
+			}
+			node = & iter->second;
+		}
+
+		CPPUNIT_ASSERT ( flag && node->getData ( ) == true );
+		std::cout << factory::StringDataFactory::toString ( str ) << std::endl;
+	}
+
+	CPPUNIT_ASSERT(strings.size ( ) == countStrings ( generated ) );
 }
 
 void GrammarGenerateUpToLength::testGenerate2() {
@@ -86,11 +116,28 @@ void GrammarGenerateUpToLength::testGenerate2() {
 	strings.insert(string::LinearString < >(ext::vector<DefaultSymbolType>{a, c, d}));
 	strings.insert(string::LinearString < >(ext::vector<DefaultSymbolType>{b, c, d}));
 
-	for(const string::LinearString < >& str : grammar::generate::GenerateUpToLength::generate(grammar1, 3)) {
+	ext::trie < DefaultSymbolType, bool > generated = grammar::generate::GenerateUpToLength::generate(grammar1, 3);
+
+	generated.nicePrint ( std::cout );
+	std::cout << std::endl;
+
+	for(const string::LinearString < > & str : strings) {
+		bool flag = true;
+		ext::trie < DefaultSymbolType, bool > * node = & generated;
+		for ( const DefaultSymbolType & symbol : str.getContent ( ) ) {
+			auto iter = node->getChildren ( ).find ( symbol );
+			if ( iter == node->getChildren ( ).end ( ) ) {
+				flag = false;
+				break;
+			}
+			node = & iter->second;
+		}
+
+		CPPUNIT_ASSERT ( flag && node->getData ( ) == true );
 		std::cout << factory::StringDataFactory::toString ( str ) << std::endl;
 	}
 
-	CPPUNIT_ASSERT(strings == grammar::generate::GenerateUpToLength::generate(grammar1, 3));
+	CPPUNIT_ASSERT(strings.size ( ) == countStrings ( generated ) );
 }
 
 void GrammarGenerateUpToLength::testGenerate3() {
@@ -132,9 +179,26 @@ void GrammarGenerateUpToLength::testGenerate3() {
 	strings.insert(string::LinearString < >(ext::vector<DefaultSymbolType>{c, d, d}));
 	strings.insert(string::LinearString < >(ext::vector<DefaultSymbolType>{b, c, d}));
 
-	for(const string::LinearString < >& str : grammar::generate::GenerateUpToLength::generate(grammar1, 3)) {
+	ext::trie < DefaultSymbolType, bool > generated = grammar::generate::GenerateUpToLength::generate(grammar1, 3);
+
+	generated.nicePrint ( std::cout );
+	std::cout << std::endl;
+
+	for(const string::LinearString < > & str : strings) {
+		bool flag = true;
+		ext::trie < DefaultSymbolType, bool > * node = & generated;
+		for ( const DefaultSymbolType & symbol : str.getContent ( ) ) {
+			auto iter = node->getChildren ( ).find ( symbol );
+			if ( iter == node->getChildren ( ).end ( ) ) {
+				flag = false;
+				break;
+			}
+			node = & iter->second;
+		}
+
+		CPPUNIT_ASSERT ( flag && node->getData ( ) == true );
 		std::cout << factory::StringDataFactory::toString ( str ) << std::endl;
 	}
 
-	CPPUNIT_ASSERT(strings == grammar::generate::GenerateUpToLength::generate(grammar1, 3));
+	CPPUNIT_ASSERT(strings.size ( ) == countStrings ( generated ) );
 }
diff --git a/alib2common/src/PrimitiveRegistrator.cpp b/alib2common/src/PrimitiveRegistrator.cpp
index b136f76403..3b416e9bbf 100644
--- a/alib2common/src/PrimitiveRegistrator.cpp
+++ b/alib2common/src/PrimitiveRegistrator.cpp
@@ -12,6 +12,8 @@
 
 #include <object/Object.h>
 
+#include <alib/trie>
+
 namespace {
 
 class PrimitiveRegistrator {
@@ -49,6 +51,7 @@ public:
 
 		abstraction::ValuePrinterRegistry::registerValuePrinter < object::Object > ( );
 		abstraction::ValuePrinterRegistry::registerValuePrinter < ext::set < object::Object > > ( );
+		abstraction::ValuePrinterRegistry::registerValuePrinter < ext::trie < object::Object, bool > > ( );
 		abstraction::ValuePrinterRegistry::registerValuePrinter < ext::set < unsigned > > ( );
 
 		abstraction::ValuePrinterRegistry::registerValuePrinter < ext::vector < object::Object > > ( );
diff --git a/alib2data/src/PrimitiveRegistrator.cpp b/alib2data/src/PrimitiveRegistrator.cpp
index e8e75be798..d1e5bbd2d2 100644
--- a/alib2data/src/PrimitiveRegistrator.cpp
+++ b/alib2data/src/PrimitiveRegistrator.cpp
@@ -11,11 +11,13 @@
 #include <registry/XmlContainerParserRegistry.hpp>
 
 #include <primitive/xml/UnsignedLong.h>
+#include <primitive/xml/Bool.h>
 
 #include <container/xml/ObjectsSet.h>
 #include <container/xml/ObjectsMap.h>
 #include <container/xml/ObjectsVector.h>
 #include <container/xml/ObjectsVariant.h>
+#include <container/xml/ObjectsTrie.h>
 
 #include <common/ranked_symbol.hpp>
 #include <alphabet/xml/RankedSymbol.h>
@@ -37,6 +39,7 @@ public:
 		registration::XmlWriterRegister < ext::vector < ext::vector < ext::set < object::Object > > > > ( );
 		registration::XmlWriterRegister < ext::map < common::ranked_symbol < object::Object, unsigned >, size_t > > ( );
 		registration::XmlWriterRegister < ext::set < string::LinearString < > > > ( );
+		registration::XmlWriterRegister < ext::trie < DefaultSymbolType, bool > > ( );
 
 		abstraction::XmlParserRegistry::registerXmlParser < object::Object > ( "DefaultStateType" );
 
diff --git a/alib2std/src/extensions/container/forward_tree.hpp b/alib2std/src/extensions/container/forward_tree.hpp
index accdb9ed59..b36e280ad8 100644
--- a/alib2std/src/extensions/container/forward_tree.hpp
+++ b/alib2std/src/extensions/container/forward_tree.hpp
@@ -921,7 +921,7 @@ public:
 	 *
 	 * \param position the specification of position in children where to erase the subtree
 	 */
-	const_children_iterator erase ( const_children_iterator position ) {
+	children_iterator erase ( const_children_iterator position ) {
 		ext::vector < forward_tree > & children = const_cast < ext::vector < forward_tree > & > ( getChildren ( ) );
 
 		return children.erase ( position );
diff --git a/alib2std/src/extensions/container/tree.hpp b/alib2std/src/extensions/container/tree.hpp
index a75c3f8087..bed9b6768a 100644
--- a/alib2std/src/extensions/container/tree.hpp
+++ b/alib2std/src/extensions/container/tree.hpp
@@ -1012,7 +1012,7 @@ public:
 	 *
 	 * \param position the specification of position in children where to erase the subtree
 	 */
-	const_children_iterator erase ( const_children_iterator position ) {
+	children_iterator erase ( const_children_iterator position ) {
 		ext::vector < tree > & children = const_cast < ext::vector < tree > & > ( getChildren ( ) );
 
 		return children.erase ( position );
diff --git a/alib2std/src/extensions/container/trie.hpp b/alib2std/src/extensions/container/trie.hpp
index 00f82983cf..f0a9900fc2 100644
--- a/alib2std/src/extensions/container/trie.hpp
+++ b/alib2std/src/extensions/container/trie.hpp
@@ -356,7 +356,7 @@ public:
 	 *
 	 * \param position the specification of position in children where to erase the subtrie
 	 */
-	const_children_iterator erase ( const_children_iterator position ) {
+	children_iterator erase ( const_children_iterator position ) {
 		ext::map < Key, trie > & children = const_cast < ext::map < Key, trie > & > ( getChildren ( ) );
 
 		return children.erase ( position );
diff --git a/tests.anormalize.sh b/tests.anormalize.sh
index 2a5a50ae73..47e327f742 100755
--- a/tests.anormalize.sh
+++ b/tests.anormalize.sh
@@ -58,7 +58,7 @@ function log {
 }
 
 function generateCFG {
-	./arand2 -t CFG --density $RAND_DENSITY --nonterminals $(( $RANDOM % $RAND_NONTERMINALS + 1 )) --terminals $(( $RANDOM % $RAND_TERMINALS + 1 )) 2>/dev/null | ./aepsilon2
+	./arand2 -t CFG --density $RAND_DENSITY --nonterminals $(( $RANDOM % $RAND_NONTERMINALS + 1 )) --terminals $(( $RANDOM % $RAND_TERMINALS + 1 )) 2>/dev/null | ./aepsilon2 | ./aql2 -q "execute grammar::simplify::SimpleRulesRemover < #stdin > #stdout"
 }
 
 # $1 = command for conversion. Output of such command must be a grammar !!
-- 
GitLab