From 4bc296c08a54a395fbfd0f4738e6cbaabf32c1f7 Mon Sep 17 00:00:00 2001 From: Jan Travnicek <Jan.Travnicek@fit.cvut.cz> Date: Tue, 15 Jan 2019 13:44:32 +0100 Subject: [PATCH] speedup generate strings up to given length from grammar --- .../grammar/generate/GenerateUpToLength.cpp | 10 +-- .../src/grammar/generate/GenerateUpToLength.h | 76 ++++++++++++++----- .../generate/GrammarGenerateUpToLength.cpp | 74 ++++++++++++++++-- alib2common/src/PrimitiveRegistrator.cpp | 3 + alib2data/src/PrimitiveRegistrator.cpp | 3 + .../src/extensions/container/forward_tree.hpp | 2 +- alib2std/src/extensions/container/tree.hpp | 2 +- alib2std/src/extensions/container/trie.hpp | 2 +- tests.anormalize.sh | 2 +- 9 files changed, 139 insertions(+), 35 deletions(-) diff --git a/alib2algo/src/grammar/generate/GenerateUpToLength.cpp b/alib2algo/src/grammar/generate/GenerateUpToLength.cpp index a481ff3eba..9a0e6b4ad0 100644 --- a/alib2algo/src/grammar/generate/GenerateUpToLength.cpp +++ b/alib2algo/src/grammar/generate/GenerateUpToLength.cpp @@ -12,11 +12,11 @@ namespace grammar { namespace generate { -auto GenerateUpToLengthEpsilonFreeCFG = registration::AbstractRegister < GenerateUpToLength, ext::set < string::LinearString < > >, const grammar::EpsilonFreeCFG < > &, unsigned > ( GenerateUpToLength::generate ); -auto GenerateUpToLengthGNF = registration::AbstractRegister < GenerateUpToLength, ext::set < string::LinearString < > >, const grammar::GNF < > &, unsigned > ( GenerateUpToLength::generate ); -auto GenerateUpToLengthCNF = registration::AbstractRegister < GenerateUpToLength, ext::set < string::LinearString < > >, const grammar::CNF < > &, unsigned > ( GenerateUpToLength::generate ); -auto GenerateUpToLengthLeftRG = registration::AbstractRegister < GenerateUpToLength, ext::set < string::LinearString < > >, const grammar::LeftRG < > &, unsigned > ( GenerateUpToLength::generate ); -auto GenerateUpToLengthRightRG = registration::AbstractRegister < GenerateUpToLength, ext::set < string::LinearString < > >, const grammar::RightRG < > &, unsigned > ( GenerateUpToLength::generate ); +auto GenerateUpToLengthEpsilonFreeCFG = registration::AbstractRegister < GenerateUpToLength, ext::trie < DefaultSymbolType, bool >, const grammar::EpsilonFreeCFG < > &, unsigned > ( GenerateUpToLength::generate ); +auto GenerateUpToLengthGNF = registration::AbstractRegister < GenerateUpToLength, ext::trie < DefaultSymbolType, bool >, const grammar::GNF < > &, unsigned > ( GenerateUpToLength::generate ); +auto GenerateUpToLengthCNF = registration::AbstractRegister < GenerateUpToLength, ext::trie < DefaultSymbolType, bool >, const grammar::CNF < > &, unsigned > ( GenerateUpToLength::generate ); +auto GenerateUpToLengthLeftRG = registration::AbstractRegister < GenerateUpToLength, ext::trie < DefaultSymbolType, bool >, const grammar::LeftRG < > &, unsigned > ( GenerateUpToLength::generate ); +auto GenerateUpToLengthRightRG = registration::AbstractRegister < GenerateUpToLength, ext::trie < DefaultSymbolType, bool >, const grammar::RightRG < > &, unsigned > ( GenerateUpToLength::generate ); } /* namespace generate */ diff --git a/alib2algo/src/grammar/generate/GenerateUpToLength.h b/alib2algo/src/grammar/generate/GenerateUpToLength.h index 5007490124..d1efd89586 100644 --- a/alib2algo/src/grammar/generate/GenerateUpToLength.h +++ b/alib2algo/src/grammar/generate/GenerateUpToLength.h @@ -12,6 +12,8 @@ #include <string/LinearString.h> #include <alib/set> #include <alib/deque> +#include <alib/trie> +#include <alib/iterator> #include <grammar/ContextFree/EpsilonFreeCFG.h> #include <grammar/ContextFree/GNF.h> @@ -29,30 +31,55 @@ namespace generate { * Implements algorithms from Melichar, chapter 3.3 */ class GenerateUpToLength { + template < class TerminalSymbolType > + static bool pruneTrie ( ext::trie < TerminalSymbolType, bool > & trie ) { + for ( typename std::map < TerminalSymbolType, ext::trie < TerminalSymbolType, bool > >::iterator iter = trie.getChildren ( ).begin ( ); iter != trie.getChildren ( ).end ( ); ) + if ( pruneTrie ( iter->second ) ) + iter = trie.erase ( iter ); + else + ++ iter; + + return trie.getData ( ) == false && trie.getChildren ( ).size ( ) == 0; + } + public: template < class T, class TerminalSymbolType = typename grammar::TerminalSymbolTypeOfGrammar < T >, class NontermimnalSymbolType = typename grammar::NonterminalSymbolTypeOfGrammar < T > > - static ext::set < string::LinearString < TerminalSymbolType > > generate ( const T & grammar, unsigned length ); + static ext::trie < TerminalSymbolType, bool > generate ( const T & grammar, unsigned length ); }; template < class T, class TerminalSymbolType, class NonterminalSymbolType > -ext::set < string::LinearString < TerminalSymbolType > > GenerateUpToLength::generate ( const T & grammar, unsigned length ) { - ext::set < string::LinearString < TerminalSymbolType > > res; +ext::trie < TerminalSymbolType, bool > GenerateUpToLength::generate ( const T & grammar, unsigned length ) { + ext::trie < TerminalSymbolType, bool > res ( false ); ext::map < NonterminalSymbolType, ext::set < ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > > > rules = grammar::RawRules::getRawRules ( grammar ); if ( grammar.getGeneratesEpsilon ( ) ) { - res.insert ( string::LinearString < TerminalSymbolType > { } ); + res.getData ( ) = true; // TODO improve interface rules [ grammar.getInitialSymbol ( ) ].erase ( ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > { } ); } - ext::deque < std::pair < ext::vector < TerminalSymbolType >, ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > > > data; - data.push_back ( std::make_pair ( ext::vector < TerminalSymbolType > { }, ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > { ext::variant < TerminalSymbolType, NonterminalSymbolType > ( grammar.getInitialSymbol ( ) ) } ) ); + struct Node { + ext::variant < TerminalSymbolType, NonterminalSymbolType > m_symbol; + std::shared_ptr < Node > m_parent; + unsigned m_depth; + + Node ( ext::variant < TerminalSymbolType, NonterminalSymbolType > symbol, std::shared_ptr < Node > parent, unsigned depth ) : m_symbol ( std::move ( symbol ) ), m_parent ( std::move ( parent ) ), m_depth ( depth ) { + } + }; + + ext::deque < std::tuple < typename ext::trie < TerminalSymbolType, bool > *, unsigned, std::shared_ptr < Node > > > data; + data.push_back ( std::make_tuple ( & res, 0, std::make_shared < Node > ( grammar.getInitialSymbol ( ), nullptr, 1 ) ) ); while ( ! data.empty ( ) ) { - std::pair < ext::vector < TerminalSymbolType >, ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > > item = std::move ( data.back ( ) ); + unsigned generatedLen = std::get < 1 > ( data.back ( ) ); + typename ext::trie < TerminalSymbolType, bool > * trieNode = std::get < 0 > ( data.back ( ) ); + + std::shared_ptr < Node > stackTop = std::get < 2 > ( data.back ( ) ); + ext::variant < TerminalSymbolType, NonterminalSymbolType > top = stackTop->m_symbol; + unsigned stackSize = stackTop->m_depth - 1; + stackTop = stackTop->m_parent; + data.pop_back ( ); - ext::variant < TerminalSymbolType, NonterminalSymbolType > top = std::move ( item.second.back ( ) ); - item.second.pop_back ( ); if ( ! top.template is < NonterminalSymbolType > ( ) ) // maybe not needed continue; auto rule = rules.find ( top.template get < NonterminalSymbolType > ( ) ); @@ -60,25 +87,32 @@ ext::set < string::LinearString < TerminalSymbolType > > GenerateUpToLength::gen continue; for ( const ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > & rhs : rule->second ) { - if ( item.first.size ( ) + item.second.size ( ) + rhs.size ( ) > length ) continue; - - ext::vector < ext::variant < TerminalSymbolType, NonterminalSymbolType > > newStack ( item.second ); - ext::vector < TerminalSymbolType > newString ( item.first ); - newStack.insert ( newStack.end ( ), rhs.rbegin ( ), rhs.rend ( ) ); - - while ( ! newStack.empty ( ) && grammar.getTerminalAlphabet ( ).count ( newStack.back ( ) ) ) { - newString.push_back ( std::move ( newStack.back ( ).template get < TerminalSymbolType > ( ) ) ); - newStack.pop_back ( ); + if ( generatedLen + stackSize + rhs.size ( ) > length ) + continue; + + std::shared_ptr < Node > newStackTop = stackTop; + for ( const ext::variant < TerminalSymbolType, NonterminalSymbolType > & symbol : ext::make_reverse ( rhs ) ) + newStackTop = std::make_shared < Node > ( symbol, newStackTop, newStackTop ? newStackTop->m_depth + 1 : 1 ); + + typename ext::trie < TerminalSymbolType, bool > * newTrieNode = trieNode; + unsigned newGeneratedLen = generatedLen; + while ( newStackTop != nullptr && grammar.getTerminalAlphabet ( ).count ( newStackTop->m_symbol ) ) { + newTrieNode->insert ( newStackTop->m_symbol.template get < TerminalSymbolType > ( ), ext::trie < TerminalSymbolType, bool > ( false ) ); + newTrieNode = & newTrieNode->getChildren ( ).at ( newStackTop->m_symbol.template get < TerminalSymbolType > ( ) ); + newStackTop = newStackTop->m_parent; + newGeneratedLen += 1; } - if ( newStack.empty ( ) ) { - res.insert ( string::LinearString < TerminalSymbolType > ( newString ) ); + if ( newStackTop == nullptr ) { + newTrieNode->getData ( ) = true; } else { - data.push_back ( std::make_pair ( std::move ( newString ), std::move ( newStack ) ) ); + data.push_back ( std::make_tuple ( newTrieNode, newGeneratedLen, newStackTop ) ); } } } + pruneTrie ( res ); + return res; } diff --git a/alib2algo/test-src/grammar/generate/GrammarGenerateUpToLength.cpp b/alib2algo/test-src/grammar/generate/GrammarGenerateUpToLength.cpp index 023d3c1a74..76ac22a25b 100644 --- a/alib2algo/test-src/grammar/generate/GrammarGenerateUpToLength.cpp +++ b/alib2algo/test-src/grammar/generate/GrammarGenerateUpToLength.cpp @@ -17,6 +17,15 @@ void GrammarGenerateUpToLength::setUp() { void GrammarGenerateUpToLength::tearDown() { } +unsigned countStrings ( const ext::trie < DefaultSymbolType, bool > & node ) { + unsigned res = node.getData ( ); + + for ( const auto & child : node.getChildren ( ) ) + res += countStrings ( child.second ); + + return res; +} + void GrammarGenerateUpToLength::testGenerate1() { DefaultSymbolType S = DefaultSymbolType("S"); DefaultSymbolType A = DefaultSymbolType("A"); @@ -34,7 +43,28 @@ void GrammarGenerateUpToLength::testGenerate1() { ext::set<string::LinearString < >> strings; - CPPUNIT_ASSERT(strings == grammar::generate::GenerateUpToLength::generate(grammar1, 5)); + ext::trie < DefaultSymbolType, bool > generated = grammar::generate::GenerateUpToLength::generate(grammar1, 5); + + generated.nicePrint ( std::cout ); + std::cout << std::endl; + + for(const string::LinearString < > & str : strings) { + bool flag = true; + ext::trie < DefaultSymbolType, bool > * node = & generated; + for ( const DefaultSymbolType & symbol : str.getContent ( ) ) { + auto iter = node->getChildren ( ).find ( symbol ); + if ( iter == node->getChildren ( ).end ( ) ) { + flag = false; + break; + } + node = & iter->second; + } + + CPPUNIT_ASSERT ( flag && node->getData ( ) == true ); + std::cout << factory::StringDataFactory::toString ( str ) << std::endl; + } + + CPPUNIT_ASSERT(strings.size ( ) == countStrings ( generated ) ); } void GrammarGenerateUpToLength::testGenerate2() { @@ -86,11 +116,28 @@ void GrammarGenerateUpToLength::testGenerate2() { strings.insert(string::LinearString < >(ext::vector<DefaultSymbolType>{a, c, d})); strings.insert(string::LinearString < >(ext::vector<DefaultSymbolType>{b, c, d})); - for(const string::LinearString < >& str : grammar::generate::GenerateUpToLength::generate(grammar1, 3)) { + ext::trie < DefaultSymbolType, bool > generated = grammar::generate::GenerateUpToLength::generate(grammar1, 3); + + generated.nicePrint ( std::cout ); + std::cout << std::endl; + + for(const string::LinearString < > & str : strings) { + bool flag = true; + ext::trie < DefaultSymbolType, bool > * node = & generated; + for ( const DefaultSymbolType & symbol : str.getContent ( ) ) { + auto iter = node->getChildren ( ).find ( symbol ); + if ( iter == node->getChildren ( ).end ( ) ) { + flag = false; + break; + } + node = & iter->second; + } + + CPPUNIT_ASSERT ( flag && node->getData ( ) == true ); std::cout << factory::StringDataFactory::toString ( str ) << std::endl; } - CPPUNIT_ASSERT(strings == grammar::generate::GenerateUpToLength::generate(grammar1, 3)); + CPPUNIT_ASSERT(strings.size ( ) == countStrings ( generated ) ); } void GrammarGenerateUpToLength::testGenerate3() { @@ -132,9 +179,26 @@ void GrammarGenerateUpToLength::testGenerate3() { strings.insert(string::LinearString < >(ext::vector<DefaultSymbolType>{c, d, d})); strings.insert(string::LinearString < >(ext::vector<DefaultSymbolType>{b, c, d})); - for(const string::LinearString < >& str : grammar::generate::GenerateUpToLength::generate(grammar1, 3)) { + ext::trie < DefaultSymbolType, bool > generated = grammar::generate::GenerateUpToLength::generate(grammar1, 3); + + generated.nicePrint ( std::cout ); + std::cout << std::endl; + + for(const string::LinearString < > & str : strings) { + bool flag = true; + ext::trie < DefaultSymbolType, bool > * node = & generated; + for ( const DefaultSymbolType & symbol : str.getContent ( ) ) { + auto iter = node->getChildren ( ).find ( symbol ); + if ( iter == node->getChildren ( ).end ( ) ) { + flag = false; + break; + } + node = & iter->second; + } + + CPPUNIT_ASSERT ( flag && node->getData ( ) == true ); std::cout << factory::StringDataFactory::toString ( str ) << std::endl; } - CPPUNIT_ASSERT(strings == grammar::generate::GenerateUpToLength::generate(grammar1, 3)); + CPPUNIT_ASSERT(strings.size ( ) == countStrings ( generated ) ); } diff --git a/alib2common/src/PrimitiveRegistrator.cpp b/alib2common/src/PrimitiveRegistrator.cpp index b136f76403..3b416e9bbf 100644 --- a/alib2common/src/PrimitiveRegistrator.cpp +++ b/alib2common/src/PrimitiveRegistrator.cpp @@ -12,6 +12,8 @@ #include <object/Object.h> +#include <alib/trie> + namespace { class PrimitiveRegistrator { @@ -49,6 +51,7 @@ public: abstraction::ValuePrinterRegistry::registerValuePrinter < object::Object > ( ); abstraction::ValuePrinterRegistry::registerValuePrinter < ext::set < object::Object > > ( ); + abstraction::ValuePrinterRegistry::registerValuePrinter < ext::trie < object::Object, bool > > ( ); abstraction::ValuePrinterRegistry::registerValuePrinter < ext::set < unsigned > > ( ); abstraction::ValuePrinterRegistry::registerValuePrinter < ext::vector < object::Object > > ( ); diff --git a/alib2data/src/PrimitiveRegistrator.cpp b/alib2data/src/PrimitiveRegistrator.cpp index e8e75be798..d1e5bbd2d2 100644 --- a/alib2data/src/PrimitiveRegistrator.cpp +++ b/alib2data/src/PrimitiveRegistrator.cpp @@ -11,11 +11,13 @@ #include <registry/XmlContainerParserRegistry.hpp> #include <primitive/xml/UnsignedLong.h> +#include <primitive/xml/Bool.h> #include <container/xml/ObjectsSet.h> #include <container/xml/ObjectsMap.h> #include <container/xml/ObjectsVector.h> #include <container/xml/ObjectsVariant.h> +#include <container/xml/ObjectsTrie.h> #include <common/ranked_symbol.hpp> #include <alphabet/xml/RankedSymbol.h> @@ -37,6 +39,7 @@ public: registration::XmlWriterRegister < ext::vector < ext::vector < ext::set < object::Object > > > > ( ); registration::XmlWriterRegister < ext::map < common::ranked_symbol < object::Object, unsigned >, size_t > > ( ); registration::XmlWriterRegister < ext::set < string::LinearString < > > > ( ); + registration::XmlWriterRegister < ext::trie < DefaultSymbolType, bool > > ( ); abstraction::XmlParserRegistry::registerXmlParser < object::Object > ( "DefaultStateType" ); diff --git a/alib2std/src/extensions/container/forward_tree.hpp b/alib2std/src/extensions/container/forward_tree.hpp index accdb9ed59..b36e280ad8 100644 --- a/alib2std/src/extensions/container/forward_tree.hpp +++ b/alib2std/src/extensions/container/forward_tree.hpp @@ -921,7 +921,7 @@ public: * * \param position the specification of position in children where to erase the subtree */ - const_children_iterator erase ( const_children_iterator position ) { + children_iterator erase ( const_children_iterator position ) { ext::vector < forward_tree > & children = const_cast < ext::vector < forward_tree > & > ( getChildren ( ) ); return children.erase ( position ); diff --git a/alib2std/src/extensions/container/tree.hpp b/alib2std/src/extensions/container/tree.hpp index a75c3f8087..bed9b6768a 100644 --- a/alib2std/src/extensions/container/tree.hpp +++ b/alib2std/src/extensions/container/tree.hpp @@ -1012,7 +1012,7 @@ public: * * \param position the specification of position in children where to erase the subtree */ - const_children_iterator erase ( const_children_iterator position ) { + children_iterator erase ( const_children_iterator position ) { ext::vector < tree > & children = const_cast < ext::vector < tree > & > ( getChildren ( ) ); return children.erase ( position ); diff --git a/alib2std/src/extensions/container/trie.hpp b/alib2std/src/extensions/container/trie.hpp index 00f82983cf..f0a9900fc2 100644 --- a/alib2std/src/extensions/container/trie.hpp +++ b/alib2std/src/extensions/container/trie.hpp @@ -356,7 +356,7 @@ public: * * \param position the specification of position in children where to erase the subtrie */ - const_children_iterator erase ( const_children_iterator position ) { + children_iterator erase ( const_children_iterator position ) { ext::map < Key, trie > & children = const_cast < ext::map < Key, trie > & > ( getChildren ( ) ); return children.erase ( position ); diff --git a/tests.anormalize.sh b/tests.anormalize.sh index 2a5a50ae73..47e327f742 100755 --- a/tests.anormalize.sh +++ b/tests.anormalize.sh @@ -58,7 +58,7 @@ function log { } function generateCFG { - ./arand2 -t CFG --density $RAND_DENSITY --nonterminals $(( $RANDOM % $RAND_NONTERMINALS + 1 )) --terminals $(( $RANDOM % $RAND_TERMINALS + 1 )) 2>/dev/null | ./aepsilon2 + ./arand2 -t CFG --density $RAND_DENSITY --nonterminals $(( $RANDOM % $RAND_NONTERMINALS + 1 )) --terminals $(( $RANDOM % $RAND_TERMINALS + 1 )) 2>/dev/null | ./aepsilon2 | ./aql2 -q "execute grammar::simplify::SimpleRulesRemover < #stdin > #stdout" } # $1 = command for conversion. Output of such command must be a grammar !! -- GitLab