#include "bc_compiler.h"
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

// DEBUG
#include <stdio.h>

/*
 * Initializes `str` as an empty byte buffer with INIT_STRING_LENGTH bytes of
 * capacity. Aborts if the allocation fails: an assert alone would vanish
 * under NDEBUG and leave a NULL buffer behind.
 */
void string_init    ( String * str ) {
  str -> capacity = INIT_STRING_LENGTH;
  str -> len = 0;
  str -> str = malloc ( INIT_STRING_LENGTH );
  if ( ! str -> str )
    abort ();
}

/*
 * Releases the buffer owned by `str` and resets it to an empty state with no
 * storage. Clearing the pointer makes a repeated destroy a harmless
 * free ( NULL ) instead of a double free.
 */
void string_destroy ( String * str ) {
  free ( str -> str );
  str -> str = NULL;
  str -> len = 0;
  str -> capacity = 0;
}

/*
 * Appends one byte to `str`, growing the buffer when it is full.
 *
 * Capacity doubles on each growth. A string with zero capacity starts from a
 * small minimum, because doubling 0 would never make room. realloc preserves
 * the existing contents, replacing the manual malloc + memcpy + free.
 * Aborts if the doubled size would wrap around or the allocation fails.
 */
void string_write_byte ( String * str, const u8  data ) {
  enum { STRING_MIN_CAPACITY = 16 };
  if ( str -> len == str -> capacity ) {
    size_t new_capacity = str -> capacity
                          ? (size_t) str -> capacity * 2
                          : (size_t) STRING_MIN_CAPACITY;
    // Guards against size_t wraparound when doubling a huge buffer.
    if ( new_capacity <= (size_t) str -> len )
      abort ();
    void * grown = realloc ( str -> str, new_capacity );
    if ( ! grown )
      abort ();
    str -> str = grown;
    str -> capacity = new_capacity;
  }
  str -> str [ str -> len ++ ] = data;
}

/* Appends `data` as two bytes, least significant byte first (little-endian). */
void string_write_u16  ( String * str, const u16 data ) {
  for ( int shift = 0; shift <= 8; shift += 8 )
    string_write_byte ( str, ( data >> shift ) & 0xFF );
}

/*
 * Appends `data` as four bytes of its two's-complement representation, least
 * significant byte first (little-endian).
 *
 * The value is converted to an unsigned 32-bit integer before shifting.
 * Right-shifting a negative signed value is implementation-defined in C,
 * whereas signed-to-unsigned conversion is defined (modulo 2^32) and always
 * yields the two's-complement bit pattern.
 */
void string_write_i32  ( String * str, const i32 data ) {
  const uint32_t bits = (uint32_t) data;
  for ( unsigned shift = 0; shift < 32; shift += 8 )
    string_write_byte ( str, ( bits >> shift ) & 0xFF );
}

/*
 * Serializes one constant-pool entry into `str` as a one-byte kind tag
 * followed by its payload. Multi-byte fields use the little-endian
 * string_write_* helpers.
 *
 *   0x00 integer   i32 value
 *   0x01 null      (no payload)
 *   0x02 string    i32 length, then the raw bytes
 *   0x03 function  u8 parameter count, u16 locals count,
 *                  i32 code length, then the code bytes
 *   0x04 boolean   one byte, 0x01 for true / 0x00 for false
 *
 * Any other kind trips an assertion.
 */
void string_write_constant ( String * str, BCConstant constant ) {
  switch ( constant . kind ) {
    case CONSTANT_INTEGER:
      string_write_byte ( str, 0x00 );
      string_write_i32  ( str, constant . integer );
      return;

    case CONSTANT_NULL:
      string_write_byte ( str, 0x01 );
      return;

    case CONSTANT_STRING: {
      string_write_byte ( str, 0x02 );
      string_write_i32  ( str, constant . string . len );
      size_t k = 0;
      while ( k < constant . string . len )
        string_write_byte ( str, constant . string . str [ k ++ ] );
      return;
    }

    case CONSTANT_FUNCTION: {
      string_write_byte ( str, 0x03 );
      string_write_byte ( str, constant . function . parameters );
      string_write_u16  ( str, constant . function . locals );
      string_write_i32  ( str, constant . function . bc . len );
      size_t k = 0;
      while ( k < constant . function . bc . len )
        string_write_byte ( str, constant . function . bc . str [ k ++ ] );
      return;
    }

    case CONSTANT_BOOLEAN:
      string_write_byte ( str, 0x04 );
      string_write_byte ( str, constant . boolean ? 0x01 : 0x00 );
      return;

    default:
      assert ( false );
      return;
  }
}

/*
 * Builds a function constant that takes `parameters` parameters, has zero
 * locals so far, and has a freshly initialized (empty) bytecode buffer.
 */
BCConstant create_function ( u8 parameters ) {
  BCConstant constant = {
    .kind     = CONSTANT_FUNCTION,
    .function = { .parameters = parameters, .locals = 0 },
  };
  string_init ( & constant . function . bc );
  return constant;
}

/*
 * Allocates the outermost local scope of a function that takes `parameters`
 * parameters. Its trailing `locals` storage has room for
 * `parameters + MAX_SCOPE_VARIABLES` names, and entry 0 is "this".
 *
 * The size starts from sizeof ( LocalScope ) rather than a hand-written sum
 * of member sizes. A hand-written sum ignores any padding the compiler places
 * before `locals` and can under-allocate; sizeof ( LocalScope ) is never
 * smaller than the offset of the trailing storage.
 *
 * `prev` is set to NULL and both counters to 1; callers may overwrite them.
 * Aborts on allocation failure. The caller releases the scope with free ().
 */
LocalScope * create_function_scope ( u8 parameters ) {
  const size_t slots = (size_t) parameters + MAX_SCOPE_VARIABLES;
  LocalScope * scope = malloc ( sizeof ( LocalScope ) + sizeof ( Str ) * slots );
  if ( ! scope )
    abort ();
  scope -> prev = NULL;
  scope -> local_count = 1;
  scope -> used_locals = 1;
  scope -> locals [ 0 ] = STR ( "this" );
  return scope;
}

/*
 * Returns the constant-pool index of the string constant equal to `str`
 * (compared with str_eq). If no such constant exists, a new CONSTANT_STRING
 * entry is appended and its index returned. The Str is stored by value; this
 * function makes no copy of its bytes.
 */
u16 get_string_index ( BCCompilerState * state, Str str ) {
  u16 idx = 0;
  while ( idx < state -> constants . constant_count ) {
    BCConstant * candidate = & state -> constants . constants [ idx ];
    if ( candidate -> kind == CONSTANT_STRING && str_eq ( candidate -> string, str ) )
      return idx;
    ++idx;
  }
  // Not interned yet: the new entry lands at the current end of the pool.
  insert_constant ( state, (BCConstant) { .kind = CONSTANT_STRING, .string = str } );
  return idx;
}

/*
 * Resolves `name` to a local-variable index in the function being compiled
 * (the function constant at `state -> fp`). If no scope in the current chain
 * declares `name`, it is declared in the innermost scope.
 *
 * Scopes are searched innermost first by following `prev`, and the first
 * match wins, so inner declarations shadow outer ones. Each scope keeps its
 * names densely in `locals [ 0 .. used_locals )`. A nested scope's slots
 * begin at its parent's `local_count`, so entry i of a scope maps to slot
 * `base + i`, where base is `prev -> local_count` (0 when there is no
 * `prev`). The returned index is that slot plus `function -> parameters`.
 *
 * Declaring a new name raises `function -> locals` to the high-water mark of
 * the current scope's `local_count`.
 *
 * NOTE(review): the function scope's own entries ("this" and the parameters)
 * are also offset by `function -> parameters`. Confirm this matches the VM's
 * frame layout.
 */
u16 get_local_index ( BCCompilerState * state, Str name ) {
  BCFunction * function = & state -> constants . constants [ state -> fp ] . function;

  for ( LocalScope * scope = state -> scope; scope; scope = scope -> prev ) {
    const size_t base = scope -> prev ? scope -> prev -> local_count : 0;
    for ( size_t i = 0; i < scope -> used_locals; ++i )
      if ( str_eq ( scope -> locals [ i ], name ) )
        return function -> parameters + base + i;
  }

  LocalScope * current = state -> scope;
  // local_count is a u16 slot counter; refuse to let it wrap around.
  assert ( current -> local_count < 0xFFFF );
  // Store at the next dense position of this scope. local_count cannot be
  // used as the position: for a block scope it already includes the parent's
  // slots, which would put the name outside the range the search scans.
  current -> locals [ current -> used_locals ] = name;
  ++ current -> used_locals;
  ++ current -> local_count;
  if ( current -> local_count > function -> locals )
    function -> locals = current -> local_count;
  return current -> local_count - 1 + function -> parameters;
}

/*
 * Emits opcode 0x01 followed by the u16 constant-pool `index` into the
 * bytecode of the function being compiled (constant slot `state -> fp`).
 */
void gen_bc_constant ( BCCompilerState * state, u16 index ) {
  string_write_byte ( & state -> constants . constants [ state -> fp ] . function . bc, 0x01 );
  string_write_u16  ( & state -> constants . constants [ state -> fp ] . function . bc, index );
}

/*
 * Appends `constant` to the end of the constant pool. There is no
 * deduplication and no capacity check here; callers handle interning.
 */
void insert_constant ( BCCompilerState * state, BCConstant constant ) {
  BCConstant * slot = & state -> constants . constants [ state -> constants . constant_count ];
  *slot = constant;
  ++ state -> constants . constant_count;
}

/*
 * Resets a compiler state before ast_to_bc:
 * - the constant pool and globals table are empty;
 * - the null/boolean constants are marked not yet interned (-1);
 * - there is no active scope.
 *
 * `top`, `fp` and `ep` are also cleared so they are never read while
 * indeterminate. Variable access compares the current scope against `top`,
 * and generate_bc writes `ep` to its output.
 */
void bc_state_init    ( BCCompilerState * state ) {
  state -> constants . constant_count = 0;
  state -> constants . null_pos = -1;
  state -> constants . bool_pos [ 0 ] = -1;
  state -> constants . bool_pos [ 1 ] = -1;
  state -> globals . global_count = 0;
  state -> scope = NULL;
  state -> top = NULL;
  state -> fp = 0;
  state -> ep = 0;
}

/*
 * Releases the bytecode buffer of every function constant in the pool.
 * The pool entries themselves are left in place.
 *
 * TODO: string constants are stored by value as given to get_string_index.
 * Whether they own memory is not visible here; free them if they ever become
 * owned copies.
 */
void bc_state_destroy ( BCCompilerState * state ) {
  size_t k = 0;
  while ( k < state -> constants . constant_count ) {
    BCConstant * entry = & state -> constants . constants [ k ];
    if ( entry -> kind == CONSTANT_FUNCTION )
      string_destroy ( & entry -> function . bc );
    ++k;
  }
}

void ast_to_bc ( BCCompilerState * state, Ast * ast ) {
  switch ( ast -> kind ) {
    case AST_TOP: {
      AstTop * top = (AstTop*) ast;
      state -> ep = state -> fp = state -> constants . constant_count;
      insert_constant ( state, create_function ( 1 ) );
      state -> scope = create_function_scope ( 1 );
      state -> scope -> prev = NULL;
      state -> scope -> used_locals = 1;
      state -> scope -> local_count = 1;
      for ( size_t i = 0; i < top -> expression_cnt; ++i )
        ast_to_bc ( state, top -> expressions [ i ] );
      free ( state -> scope );
      return;
    }
    case AST_NULL: {
      if ( state -> constants . null_pos < 0 ) {
        state -> constants . null_pos = state -> constants . constant_count;
        insert_constant ( state, (BCConstant) { .kind = CONSTANT_NULL } );
      }
      gen_bc_constant ( state, state -> constants . null_pos );
      return;
    }
    case AST_INTEGER: {
      AstInteger * integer = (AstInteger*) ast;
      u16 i;
      for ( i = 0; i < state -> constants . constant_count; ++i )
        if ( state -> constants . constants [ i ] . kind == CONSTANT_INTEGER && state -> constants . constants [ i ] . integer == integer -> value )
          break;
      if ( i == state -> constants . constant_count )
        insert_constant ( state, (BCConstant) { .kind = CONSTANT_INTEGER, .integer = integer -> value } );
      gen_bc_constant ( state, i );
      return;
    }
    case AST_BOOLEAN: {
      AstBoolean * boolean = (AstBoolean*) ast;
      size_t pos = boolean -> value ? 1 : 0;
      if ( state -> constants . bool_pos [ pos ] < 0 ) {
        state -> constants . bool_pos [ pos ] = state -> constants . constant_count;
        insert_constant ( state, (BCConstant) { .kind = CONSTANT_BOOLEAN, .boolean = boolean -> value } );
      }
      gen_bc_constant ( state, state -> constants . bool_pos [ pos ] );
      return;
    }
    case AST_BLOCK: {
      AstBlock * block = (AstBlock*) ast;
      // make local scope
      LocalScope * scope = (LocalScope*) malloc ( sizeof (LocalScope*) + sizeof (u32) + sizeof (Str) * MAX_SCOPE_VARIABLES );
      scope -> local_count = state -> scope -> local_count;
      scope -> used_locals = 0;
      scope -> prev = state -> scope;
      state -> scope = scope;
      state -> top = scope;
      BCFunction * function = & state -> constants . constants [ state -> fp ] . function;
      for ( size_t i = 0; i < block -> expression_cnt; ++i ) {
        ast_to_bc ( state, block -> expressions [ i ] );
        if ( i != block -> expression_cnt - 1 )
          string_write_byte ( & function -> bc, 0x00 );
      }
      // delete local scope
      state -> scope = state -> scope -> prev;
      free ( scope );
      return;
    }
    case AST_FUNCTION: {
      AstFunction * func = (AstFunction*) ast;
      LocalScope * old_scope = state -> scope;
      u16 old_ep = state -> ep;
      state -> ep = state -> fp = state -> constants . constant_count;
      insert_constant ( state, create_function ( func -> parameter_cnt ) );
      state -> scope = create_function_scope ( func -> parameter_cnt );
      state -> scope -> prev = NULL;
      state -> scope -> used_locals = func -> parameter_cnt;
      state -> scope -> local_count = func -> parameter_cnt;
      for ( u16 i = 1; i < func -> parameter_cnt; ++i )
        state -> scope -> locals [ i ] = func -> parameters [ i ];
      ast_to_bc ( state, func -> body );
      free ( state -> scope );
      state -> ep = old_ep;
      state -> scope = old_scope;
      return;            
    }
    case AST_FUNCTION_CALL: {
      AstFunctionCall * call = (AstFunctionCall*) ast;
      ast_to_bc ( state, call -> function );
      for ( size_t i = 0; i < call -> argument_cnt; ++i )
        ast_to_bc ( state, call -> arguments [ i ] );
      BCFunction * function = & state -> constants . constants [ state -> fp ] . function;
      string_write_byte ( & function -> bc, 0x08 );
      string_write_byte ( & function -> bc, call -> argument_cnt & 255 );    
      return;  
    }
    case AST_PRINT: {
      AstPrint * print = (AstPrint*) ast;
      u16 index = get_string_index ( state, print -> format );
      for ( size_t i = 0; i < print -> argument_cnt; ++i )
        ast_to_bc ( state, print -> arguments [ i ] );
      BCFunction * function = & state -> constants . constants [ state -> fp ] . function;
      string_write_byte ( & function -> bc, 0x02 );
      string_write_u16 ( & function -> bc, index );
      string_write_byte ( & function -> bc, print -> argument_cnt & 255 );
      return;      
    }
    case AST_DEFINITION: {
      AstDefinition * def = (AstDefinition*) ast;
      ast_to_bc ( state, def -> value );
      if ( ! state -> scope -> prev ) {
        // GLOBAL
        u16 index = get_string_index ( state, def -> name );
        BCFunction * function = & state -> constants . constants [ state -> fp ] . function;
        u16 * count = & state -> globals . global_count;
        for ( size_t i = 0; i < *count; ++i )
          if ( state -> globals . names [ i ] == index )
            assert ( false );
        state -> globals . names [ *count ] = index;
        ++(*count);
        string_write_byte ( & function -> bc, 0x0B );
        string_write_u16 ( & function -> bc, index );
      } else {
        // LOCAL
        u16 index = get_local_index ( state, def -> name );
        BCFunction * function = & state -> constants . constants [ state -> fp ] . function;
        string_write_byte ( & function -> bc, 0x09 );
        string_write_u16 ( & function -> bc, index );
      }
      return;
    }
    case AST_VARIABLE_ACCESS: {
      AstVariableAccess * var = (AstVariableAccess*) ast;
      if ( state -> scope == state -> top ) {
        // GLOBAL
        u16 index = get_string_index ( state, var -> name );
        BCFunction * function = & state -> constants . constants [ state -> fp ] . function;
        u16 * count = & state -> globals . global_count;
        size_t i;
        for ( i = 0; i < *count; ++i )
          if ( state -> globals . names [ i ] == index )
            break;
        if ( i == *count )
          assert ( false );
        string_write_byte ( & function -> bc, 0x0C );
        string_write_u16 ( & function -> bc, index );
      } else {
        // LOCAL
        u16 index = get_local_index ( state, var -> name );
        BCFunction * function = & state -> constants . constants [ state -> fp ] . function;
        string_write_byte ( & function -> bc, 0x0A );
        string_write_u16 ( & function -> bc, index );
      }
      return;
    }
    case AST_VARIABLE_ASSIGNMENT: {
      AstVariableAssignment * assign = (AstVariableAssignment*) ast;
      ast_to_bc ( state, assign -> value );
      if ( state -> scope == state -> top ) {
        // GLOBAL
        u16 constants = state -> constants . constant_count; 
        u16 index = get_string_index ( state, assign -> name );
        assert ( constants == state -> constants . constant_count );
        BCFunction * function = & state -> constants . constants [ state -> fp ] . function;
        u16 * count = & state -> globals . global_count;
        size_t i;
        for ( i = 0; i < *count; ++i )
          if ( state -> globals . names [ i ] == index )
            break;
        if ( i == *count )
          assert ( false );
        string_write_byte ( & function -> bc, 0x0B );
        string_write_u16 ( & function -> bc, index );
      } else {
        // LOCAL
        u16 locals = state -> scope -> local_count;
        u16 index = get_local_index ( state, assign -> name );
        assert ( locals == state -> scope -> local_count );
        BCFunction * function = & state -> constants . constants [ state -> fp ] . function;
        string_write_byte ( & function -> bc, 0x09 );
        string_write_u16 ( & function -> bc, index );
      }
      return;
    }
    default:
      assert ( false );
  }
}

/*
 * Compiles `ast` and serializes the result into a newly initialized buffer.
 * The caller is responsible for the returned String (string_destroy frees it).
 *
 * Output layout (multi-byte integers are little-endian via string_write_*):
 *   "FML\n"                                  4-byte header
 *   u16 constant count, then each constant   (see string_write_constant)
 *   u16 global count, then each global's u16 name-constant index
 *   u16 entry-point constant index
 */
String generate_bc ( Ast * ast ) {
  String out;
  string_init ( & out );

  BCCompilerState state;
  bc_state_init ( & state );
  ast_to_bc ( & state, ast );

  const char * const magic = "FML\n";
  for ( const char * p = magic; *p != '\0'; ++p )
    string_write_byte ( & out, *p );

  string_write_u16 ( & out, state . constants . constant_count );
  for ( size_t k = 0; k < state . constants . constant_count; ++k )
    string_write_constant ( & out, state . constants . constants [ k ] );

  string_write_u16 ( & out, state . globals . global_count );
  for ( size_t k = 0; k < state . globals . global_count; ++k )
    string_write_u16 ( & out, state . globals . names [ k ] );

  string_write_u16 ( & out, state . ep );

  bc_state_destroy ( & state );
  return out;
}