Skip to content
Snippets Groups Projects
Commit 92e51c9e authored by Jan Trávníček's avatar Jan Trávníček
Browse files

templatise arithmetic coding

parent de0e7a7c
No related branches found
No related tags found
1 merge request!49First compression algorithms
...@@ -8,7 +8,8 @@ ...@@ -8,7 +8,8 @@
#include "ArithmeticModel.h" #include "ArithmeticModel.h"
   
#include <alib/vector> #include <alib/vector>
#include <alib/string>
#include <string/LinearString.h>
   
namespace stringology { namespace stringology {
   
...@@ -23,39 +24,37 @@ class AdaptiveIntegerArithmeticCompression { ...@@ -23,39 +24,37 @@ class AdaptiveIntegerArithmeticCompression {
} }
} }
public: public:
template < class SymbolType >
static ext::vector < bool > compress ( ext::string & source ) { static ext::vector < bool > compress ( const string::LinearString < SymbolType > & source ) {
ext::set < char > alphabet; ArithmeticModel < SymbolType > model ( source.getAlphabet ( ) );
for ( int i = 0; i < 256; ++ i )
alphabet.insert ( i );
ArithmeticModel < char > model ( alphabet );
   
ext::vector < bool > result; ext::vector < bool > result;
unsigned pending_bits = 0; unsigned pending_bits = 0;
unsigned valid_bits = sizeof ( unsigned long long ) * 8 / 2;
   
unsigned max_code = ~0u; unsigned long long max_code = ~0ull >> valid_bits;
unsigned one_half = ( max_code >> 1 ) + 1; unsigned long long one_half = ( max_code >> 1 ) + 1;
unsigned one_fourth = ( max_code >> 2 ) + 1; unsigned long long one_fourth = ( max_code >> 2 ) + 1;
unsigned three_fourths = one_half + one_fourth; unsigned long long three_fourths = one_half + one_fourth;
   
unsigned low = 0; unsigned long long low = 0;
unsigned high = max_code; unsigned long long high = max_code;
   
for ( size_t index = 0; index < source.size ( ) + 1; ++ index ) { for ( size_t index = 0; index < source.getContent ( ).size ( ) + 1; ++ index ) {
   
unsigned prob_low, prob_high; unsigned prob_low, prob_high;
unsigned prob_count = model.getCount ( ); unsigned prob_count = model.getCount ( );
   
if ( index >= source.size ( ) ) if ( index >= source.getContent ( ).size ( ) )
std::tie ( prob_low, prob_high ) = model.getProbabilityEof ( ); model.getProbabilityEof ( prob_low, prob_high );
else { else {
std::tie ( prob_low, prob_high ) = model.getProbability ( source [ index ] ); model.getProbability ( source.getContent ( ) [ index ], prob_low, prob_high );
model.update ( source [ index ] ); model.update ( source.getContent ( ) [ index ] );
} }
   
unsigned long long range = ( unsigned long long ) ( high - low ) + 1; unsigned long long range = high - low + 1;
high = low + ( unsigned ) ( range * prob_high / prob_count - 1 ); high = low + range * prob_high / prob_count - 1;
low = low + ( unsigned ) ( range * prob_low / prob_count ); low = low + range * prob_low / prob_count;
for ( ; ; ) { for ( ; ; ) {
if ( high < one_half || low >= one_half ) if ( high < one_half || low >= one_half )
put_bit_plus_pending(result, low >= one_half, pending_bits); put_bit_plus_pending(result, low >= one_half, pending_bits);
...@@ -68,6 +67,9 @@ public: ...@@ -68,6 +67,9 @@ public:
high <<= 1; high <<= 1;
high++; high++;
low <<= 1; low <<= 1;
low &= max_code;
high &= max_code;
} }
} }
pending_bits++; pending_bits++;
......
...@@ -7,19 +7,18 @@ ...@@ -7,19 +7,18 @@
   
#include "ArithmeticModel.h" #include "ArithmeticModel.h"
   
#include <string/LinearString.h>
namespace stringology { namespace stringology {
   
namespace compression { namespace compression {
   
class AdaptiveIntegerArithmeticDecompression { class AdaptiveIntegerArithmeticDecompression {
public: public:
static ext::string decompress ( ext::vector < bool > &source ) { template < class SymbolType >
ext::set < char > alphabet; static string::LinearString < SymbolType > decompress ( const ext::vector < bool > & source, const ext::set < SymbolType > & alphabet ) {
for ( int i = 0; i < 256; ++ i ) ArithmeticModel < SymbolType > model ( alphabet );
alphabet.insert ( i ); ext::vector < SymbolType > result;
ArithmeticModel < char > model ( alphabet );
ext::string result;
   
unsigned valid_bits = sizeof ( unsigned long long ) * 8 / 2; unsigned valid_bits = sizeof ( unsigned long long ) * 8 / 2;
   
...@@ -59,17 +58,17 @@ public: ...@@ -59,17 +58,17 @@ public:
if ( model.isEof ( scaled_value ) ) if ( model.isEof ( scaled_value ) )
break; break;
   
char c;
unsigned prob_low, prob_high; unsigned prob_low, prob_high;
unsigned prob_count = model.getCount ( ); unsigned prob_count = model.getCount ( );
std::tie ( prob_low, prob_high, c ) = model.getChar ( scaled_value ); SymbolType c = model.getChar ( scaled_value, prob_low, prob_high );
model.update ( c ); model.update ( c );
   
result += c; result.push_back ( c );
high = low + ( range * prob_high ) / prob_count - 1; high = low + range * prob_high / prob_count - 1;
low = low + ( range * prob_low ) / prob_count; low = low + range * prob_low / prob_count;
} }
return result;
return string::LinearString < SymbolType > ( alphabet, result );
} }
}; };
   
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
   
#include <stdexcept> #include <stdexcept>
#include <alib/map> #include <alib/map>
#include <alib/variant>
#include <alib/set> #include <alib/set>
   
template < class SymbolType > template < class SymbolType >
...@@ -19,35 +18,36 @@ public: ...@@ -19,35 +18,36 @@ public:
m_global_high = frequency + 1; m_global_high = frequency + 1;
} }
   
void update ( const ext::variant < void, SymbolType > & symbol ) { void update ( const SymbolType & symbol ) {
for ( auto i = m_high_cumulative_frequency.find ( symbol ); i != m_high_cumulative_frequency.end ( ) ; ++ i ) for ( auto i = m_high_cumulative_frequency.find ( symbol ); i != m_high_cumulative_frequency.end ( ) ; ++ i )
i->second += 1; i->second += 1;
m_global_high += 1; m_global_high += 1;
} }
   
std::tuple < unsigned, unsigned > getProbability ( const SymbolType & c ) const { void getProbability ( const SymbolType & c, unsigned & low_prob, unsigned & high_prob ) const {
auto i = m_high_cumulative_frequency.find ( c ); auto i = m_high_cumulative_frequency.find ( c );
unsigned low_prob = 0; high_prob = i->second;
low_prob = 0;
   
if ( i != m_high_cumulative_frequency.begin ( ) ) if ( i != m_high_cumulative_frequency.begin ( ) )
low_prob = std::prev ( i )->second; low_prob = std::prev ( i )->second;
return std::make_tuple ( low_prob, i->second );
} }
   
std::tuple < unsigned, unsigned > getProbabilityEof ( ) const { void getProbabilityEof ( unsigned & low_prob, unsigned & high_prob ) const {
return std::make_tuple ( m_global_high - 1, m_global_high ); low_prob = m_global_high - 1;
high_prob = m_global_high;
} }
   
std::tuple < unsigned, unsigned, SymbolType > getChar ( unsigned scaled_value ) const { SymbolType getChar ( unsigned scaled_value, unsigned & low_prob, unsigned & high_prob ) const {
for ( auto i = m_high_cumulative_frequency.begin ( ); i != m_high_cumulative_frequency.end ( ); ++ i ) for ( auto i = m_high_cumulative_frequency.begin ( ); i != m_high_cumulative_frequency.end ( ); ++ i )
if ( scaled_value < i->second ) { if ( scaled_value < i->second ) {
unsigned low_prob = 0; high_prob = i->second;
low_prob = 0;
   
if ( i != m_high_cumulative_frequency.begin ( ) ) if ( i != m_high_cumulative_frequency.begin ( ) )
low_prob = std::prev ( i )->second; low_prob = std::prev ( i )->second;
   
return std::make_tuple ( low_prob, i->second, i->first ); return i->first;
} }
throw std::logic_error("error"); throw std::logic_error("error");
} }
......
...@@ -13,12 +13,15 @@ void ArithmeticCompressionTest::tearDown() { ...@@ -13,12 +13,15 @@ void ArithmeticCompressionTest::tearDown() {
} }
   
void ArithmeticCompressionTest::basics() { void ArithmeticCompressionTest::basics() {
ext::string input ( "abbabbabaae2378 8723 babababb ab bapobababbbabaaabbafjfjdjlvldsuiueqwpomvdhgataewpvdihviasubababbba 5475 baaabba" ); ext::string rawInput ( "abbabbabaae123456789r0 8723 babababb ab bapobababbbabaaabbafjfjdjlvldsuiueqwpomvdhgataewpvdihviasubababbba 5475 baaabba" );
string::LinearString < char > input ( rawInput );
ext::vector < bool > compressed = stringology::compression::AdaptiveIntegerArithmeticCompression::compress ( input ); ext::vector < bool > compressed = stringology::compression::AdaptiveIntegerArithmeticCompression::compress ( input );
std::cout << "compressed = " << compressed << std::endl; std::cout << "compressed = " << compressed << std::endl;
ext::string output = stringology::compression::AdaptiveIntegerArithmeticDecompression::decompress ( compressed ); string::LinearString < char > output = stringology::compression::AdaptiveIntegerArithmeticDecompression::decompress ( compressed, input.getAlphabet ( ) );
   
std::cout << "original= " << input << " decompressed = " << output << std::endl; std::cout << "original= " << input << std::endl << "decompressed = " << output << std::endl;
std::cout << "compressed size = " << compressed.size ( ) << std::endl << "original_size = " << input.getContent ( ).size ( ) * 8 << std::endl;
CPPUNIT_ASSERT ( input == output ); CPPUNIT_ASSERT ( input == output );
} }
   
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment