#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include "heap.h"

// CSV header written once, right after heap_init opens the heap log file;
// every later record in that file is "timestamp,event,heap".
#define HEADER "timestamp,event,heap\n"

// Default full-memory handler installed by heap_init. It performs no
// collection at all: it reports the condition and terminates with status 22.
void exit_handler ( GarbageCollector * gc ) {
  (void) gc; // unused: nothing is collected here
  fputs ( "Out of memory.\n", stderr );
  exit ( 22 );
}

// Bump-allocate `len` bytes from `heap` at an address that is a multiple of
// `align` (0 or 1 means "no alignment requirement").
// Returns the block, or NULL when the aligned block does not fit; on failure
// the heap is left untouched. The full-memory handler is never invoked.
void * heap_alloc_alligned ( Heap * heap, size_t len, size_t align ) {
  size_t pad = 0;
  if ( align > 1 ) { // guards the modulo below against align == 0
    size_t rem = (size_t) ( (uintptr_t) heap -> next % align );
    if ( rem )
      pad = align - rem;
  }
  // Compare sizes instead of pointers: forming a pointer beyond `end` is
  // undefined behaviour, and a block ending exactly at `end` does fit.
  size_t avail = (size_t) ( heap -> end - heap -> next );
  if ( pad > avail || len > avail - pad )
    return NULL;
  void * ret = heap -> next + pad;
  heap -> next += pad + len;
  return ret;
}

void * heap_alloc ( Heap * heap, size_t len ) {
  size_t pos = (size_t) heap -> next;
  size_t rem = pos % 8;
  if ( rem )
    heap -> next = heap -> next + 8 - rem;
  if ( heap -> next + len >= heap -> end ) {
    struct timespec ts;
    if ( heap -> log_file ) {
      if ( ! time_get ( &ts ) )
        exit(23);
      fprintf ( heap -> log_file, "%li,%c,%lu\n", ts . tv_nsec, 'B', heap -> next - heap -> begin );
    }
    heap -> full_mem_handler ( heap -> gc );
    if ( heap -> log_file ) {
      if ( ! time_get ( &ts ) )
        exit(23);
      fprintf ( heap -> log_file, "%li,%c,%lu\n", ts . tv_nsec, 'A', heap -> next - heap -> begin );
    }
    pos = (size_t) heap -> next;
    rem = pos % 8;
    if ( rem )
      heap -> next = heap -> next + 8 - rem;
  }  
  void * ret = heap -> next;
  heap -> next += len;
  return ret;
}

void heap_init ( Heap * heap, size_t heap_size, const char * file ) {
  heap -> begin = (u8*) malloc ( heap_size );
  heap -> next = heap -> begin;
  heap -> end = heap -> begin + heap_size;
  heap -> gc = NULL;
  if ( file ) {
    heap -> log_file = fopen ( file, "w" );
    fprintf ( heap -> log_file, HEADER );
    struct timespec ts;
    if ( ! time_get ( &ts ) )
      exit(23);
    fprintf ( heap -> log_file, "%li,%c,%u\n", ts . tv_nsec, 'S', 0 );
  } else
    heap -> log_file = NULL;
  heap -> full_mem_handler = exit_handler;
}

void heap_destroy ( Heap * heap ) {
  free ( heap -> begin );
  if ( heap -> log_file ) {
    struct timespec ts;
    if ( ! time_get ( &ts ) )
      exit(23);
    fprintf ( heap -> log_file, "%li,%c,%lu\n", ts . tv_nsec, 'E', heap -> next - heap -> begin );
    fclose ( heap -> log_file );
  }
}

// Return the to-space address of `value`, evacuating it if necessary:
//  - already inside gc->to: returned unchanged;
//  - inside gc->from and already forwarded (VALUE_HEAP_POINTER): the
//    forwarding target is returned;
//  - otherwise: copied via copy_value. The from-space original is then
//    overwritten with a VALUE_HEAP_POINTER record pointing at the copy.
// Debug builds assert that `value` lies in one of the two semispaces.
Value * in_to_semispace ( GarbageCollector * gc, Value * value ) {
  u8 * raw = (u8*) value;
  if ( raw >= gc -> to -> begin && raw < gc -> to -> end )
    return value;
  assert ( raw >= gc -> from -> begin && raw < gc -> from -> end );
  HeapPointerValue * slot = (HeapPointerValue*) value;
  if ( value -> kind == VALUE_HEAP_POINTER )
    return (Value*) slot -> to;
  // Copy first: the forwarding record written below clobbers the original.
  Value * copy = copy_value ( gc, value );
  *slot = (HeapPointerValue) { .kind = (Value) { .kind = VALUE_HEAP_POINTER }, .to = copy };
  return copy;
}


// Make a shallow copy of `value` in gc->to and return the copy.
// Pointer fields (array elements, object members, `extends`) are copied
// verbatim, so they still refer to the old locations. They are fixed up
// later, when gc_collect scans to-space.
// An unknown kind means the heap is corrupted. This is reported and the
// process exits with status 21 (the same status gc_collect uses), rather than
// returning NULL, which in_to_semispace would install as a forwarding target.
Value * copy_value ( GarbageCollector * gc, Value * value ) {
  switch ( value -> kind ) {
    case VALUE_INTEGER:
      return make_int ( gc -> to, ((IntValue *) value) -> integer );
    case VALUE_BOOLEAN:
      return make_bool ( gc -> to, ((BoolValue *) value) -> boolean );
    case VALUE_NULL:
      return make_null ( gc -> to );
    case VALUE_ARRAY: {
      ArrayValue * arr = (ArrayValue*) value;
      ArrayValue * moved = (ArrayValue*) make_array ( gc -> to, arr -> length );
      // size_t index: `length` is a size (make_array takes size_t); an int
      // counter would mix signedness and overflow past INT_MAX.
      for ( size_t i = 0; i < arr -> length; ++i )
        moved -> elements [ i ] = arr -> elements [ i ];
      return (Value*) moved;
    }
    case VALUE_FUNCTION: {
      CFunctionValue * func = (CFunctionValue*) value;
      return make_cfunction ( gc -> to, func -> function );
    }
    case VALUE_OBJECT: {
      ObjectValue * obj = (ObjectValue*) value;
      ObjectValue * moved = (ObjectValue*) make_object ( gc -> to, obj -> member_cnt );
      moved -> extends = obj -> extends;
      for ( size_t i = 0; i < obj -> member_cnt; ++i )
        moved -> members [ i ] = obj -> members [ i ];
      return (Value*) moved;
    }
    default:
      fprintf ( stderr, "Tainted value in from semispace.\n" );
      exit ( 21 );
  }
}

// Finish a copying collection.
// The objects already copied into gc->to are walked in address order. Every
// pointer they hold is replaced with its to-space address, copying the
// target via in_to_semispace if needed. Objects copied during the walk are
// appended at to->next and are visited by the same loop, so the walk ends
// when the scan pointer catches up with to->next.
// Afterwards the semispaces are swapped: gc->from describes the compacted
// live data and gc->to is reset to an empty region.
// NOTE(review): only objects already in to-space are scanned, so the caller
// presumably copies the roots before calling this; confirm.
void gc_collect ( GarbageCollector * gc ) {
  u8 * scan = gc -> to -> begin;
  // `<` rather than `!=`: `scan` is rounded up to ALIGN after every object,
  // but to->next is not rounded after the last allocation. After an object
  // whose size is not a multiple of ALIGN, the two may never be equal.
  while ( scan < gc -> to -> next ) {
    Value * value = (Value*) scan;
    switch ( value -> kind ) {
      case VALUE_ARRAY: {
        ArrayValue * arr = (ArrayValue*) value;
        for ( size_t i = 0; i < arr -> length; ++i )
          if ( arr -> elements [ i ] ) // unset slots are skipped, as for object members
            arr -> elements [ i ] = in_to_semispace ( gc, arr -> elements [ i ] );
        scan += sizeof (ArrayValue) + arr -> length * sizeof (Value*);
        break;
      }
      case VALUE_OBJECT: {
        ObjectValue * obj = (ObjectValue*) value;
        if ( obj -> extends )
          obj -> extends = in_to_semispace ( gc, obj -> extends );
        for ( size_t i = 0; i < obj -> member_cnt; ++i )
          if ( obj -> members [ i ] . value )
            obj -> members [ i ] . value = in_to_semispace ( gc, obj -> members [ i ] . value );
        scan += sizeof (ObjectValue) + obj -> member_cnt * sizeof (SimpleEntry);
        break;
      }
      // The scalar sizes below must mirror the sizes the make_* functions allocate.
      case VALUE_FUNCTION:
        scan += sizeof (CFunctionValue) + 8;
        break;
      case VALUE_INTEGER:
        scan += sizeof (IntValue) + 8;
        break;
      case VALUE_BOOLEAN:
        scan += sizeof (BoolValue) + 8;
        break;
      case VALUE_NULL:
        scan += sizeof (NullValue) + 8;
        break;
      default:
        fprintf ( stderr, "Tainted address in to semispace.\n" );
        exit ( 21 );
    }
    size_t rem = (size_t) ( (uintptr_t) scan % ALIGN );
    if ( rem )
      scan += ALIGN - rem;
  }
  // swap semispaces; the new to-space (old from-space) starts out empty
  u8 * old_from_begin = gc -> from -> begin;
  u8 * old_from_end = gc -> from -> end;
  gc -> from -> begin = gc -> to -> begin;
  gc -> from -> end = gc -> to -> end;
  gc -> from -> next = gc -> to -> next;
  gc -> to -> begin = old_from_begin;
  gc -> to -> end = old_from_end;
  gc -> to -> next = old_from_begin;
}

int compare_entry ( const void * a, const void * b ) {
  return str_cmp ( ((SimpleEntry*) a) -> name, ((SimpleEntry*) b) -> name );
}

// Truthiness of a value: null and false are falsy. Every other value,
// including the integer 0, is truthy.
bool value_to_bool ( Value * value ) {
  switch ( value -> kind ) {
    case VALUE_NULL:
      return false;
    case VALUE_BOOLEAN:
      return ((BoolValue*) value ) -> boolean;
    default:
      return true;
  }
}

Value * try_operator ( Heap * heap, Value * object, Value ** arguments, size_t argc, Str * name ) {
  switch ( object -> kind ) {
    case VALUE_INTEGER: {
      if ( argc != 1 ) {
        fprintf ( stderr, "Invalid argument count for integer operation %.*s.\n", (int) name -> len, name -> str );
        exit ( 10 );
      }
      IntValue * this = (IntValue*) object;
      bool is_int = arguments [ 0 ] -> kind == VALUE_INTEGER;
      if ( is_int ) {
        IntValue * other = (IntValue*) arguments [ 0 ];
        if ( str_eq ( *name, STR ("+") ) )
          return make_int ( heap, this -> integer + other -> integer );
        if ( str_eq ( *name, STR ("-") ) )
          return make_int ( heap, this -> integer - other -> integer );
        if ( str_eq ( *name, STR ("*") ) )
          return make_int ( heap, this -> integer * other -> integer );
        if ( str_eq ( *name, STR ("/") ) )
          return make_int ( heap, this -> integer / other -> integer );
        if ( str_eq ( *name, STR ("%") ) )
          return make_int ( heap, this -> integer % other -> integer );
        if ( str_eq ( *name, STR ("<=") ) )
          return make_bool ( heap, this -> integer <= other -> integer );
        if ( str_eq ( *name, STR (">=") ) )
          return make_bool ( heap, this -> integer >= other -> integer );
        if ( str_eq ( *name, STR ("<") ) )
          return make_bool ( heap, this -> integer < other -> integer );
        if ( str_eq ( *name, STR (">") ) )
          return make_bool ( heap, this -> integer > other -> integer );
        if ( str_eq ( *name, STR ("==") ) )
          return make_bool ( heap, this -> integer == other -> integer );
        if ( str_eq ( *name, STR ("!=") ) )
          return make_bool ( heap, this -> integer != other -> integer );  
        break;
      } 
      if ( str_eq ( *name, STR ("==") ) )
        return make_bool ( heap, false );
      if ( str_eq ( *name, STR ("!=") ) )
        return make_bool ( heap, true );
      break;
    }
    case VALUE_BOOLEAN: {
      if ( argc != 1 ) {
        fprintf ( stderr, "Invalid argument count for bool operation %.*s.\n", (int) name -> len, name -> str );
        exit ( 10 );
      }
      BoolValue * this = (BoolValue*) object;
      bool is_bool = arguments [ 0 ] -> kind == VALUE_BOOLEAN;
      if ( is_bool ) {
        BoolValue * other = (BoolValue*) arguments [ 0 ];
        if ( str_eq ( *name, STR ("&") ) )
          return make_bool ( heap, this -> boolean & other -> boolean );
        if ( str_eq ( *name, STR ("|") ) )
          return make_bool ( heap, this -> boolean | other -> boolean );
        if ( str_eq ( *name, STR ("==") ) )
          return make_bool ( heap, this -> boolean == other -> boolean );
        if ( str_eq ( *name, STR ("!=") ) )
          return make_bool ( heap, this -> boolean != other -> boolean );
        break;
      }
      if ( str_eq ( *name, STR ("==") ) )
        return make_bool ( heap, false );
      if ( str_eq ( *name, STR ("!=") ) )
        return make_bool ( heap, true );
      break;
    }
    case VALUE_NULL: {
      if ( argc != 1 ) {
        fprintf ( stderr, "Invalid amount of arguments for null operation %.*s.\n", (int) name -> len, name -> str );
        exit ( 10 );
      }
      if ( str_eq ( *name, STR ("==") ) )
        return make_bool ( heap, arguments [ 0 ] -> kind == VALUE_NULL ? true : false );
      if ( str_eq ( *name, STR ("!=") ) )
        return make_bool ( heap, arguments [ 0 ] -> kind != VALUE_NULL ? true : false );
      break;
    }
    case VALUE_ARRAY:
      if ( str_eq ( *name, STR ("get") ) ) {
        if ( argc != 1 || arguments [ 0 ] -> kind != VALUE_INTEGER ) {
          fprintf ( stderr, "Invalid argument for array get.\n" );
          exit ( 11 );
        }
        ArrayValue * array = (ArrayValue*) object;
        i32 index = ((IntValue*) arguments [ 0 ]) -> integer;
        if ( index < 0 || index >= array -> length ) {
          fprintf ( stderr, "Index is out of bounds.\n" );
          exit ( 11 );
        }
        return array -> elements [ index ];
      }
      if ( str_eq ( *name, STR ("set") ) ) {
        if ( argc != 2 || arguments [ 0 ] -> kind != VALUE_INTEGER ) {
          fprintf ( stderr, "Invalid arguments for array set.\n" );
          exit ( 12 );
        }
        ArrayValue * array = (ArrayValue*) object;
        i32 index = ((IntValue*) arguments [ 0 ]) -> integer;
        if ( index < 0 || index >= array -> length ) {
          fprintf ( stderr, "Index is out of bounds.\n" );
          exit ( 12 );
        }
        array -> elements [ index ] = arguments [ 1 ];
        return arguments [ 1 ];
      }
      break;
    default:
      break;
  }
  return NULL;
}

// Follow `extends` links from `object` until reaching a value that is not an
// object, and return that value (the base at the bottom of the chain).
// Returns NULL if the chain ends in a NULL `extends` link instead of
// dereferencing it. gc_collect also treats a NULL `extends` as possible.
Value * get_base ( Value * object ) {
  Value * curr = object;
  while ( curr && curr -> kind == VALUE_OBJECT )
    curr = ((ObjectValue *) curr) -> extends;
  return curr;
}

// Look up member `name` on `object`, then along its `extends` chain.
// Returns the address of the first matching member's value slot, so callers
// can assign through it, or NULL if no object in the chain defines `name`.
// The walk stops at the first link that is NULL or not an object; such a
// link is never dereferenced as an ObjectValue.
Value ** get_object_field ( Value * object, Str name ) {
  for ( Value * link = object; link && link -> kind == VALUE_OBJECT; link = ((ObjectValue*) link) -> extends ) {
    ObjectValue * curr = (ObjectValue*) link;
    for ( size_t i = 0; i < curr -> member_cnt; ++i )
      if ( str_eq ( curr -> members [ i ] . name, name ) )
        return & curr -> members [ i ] . value;
  }
  return NULL;
}

// Look up `name` among `object`'s own members only; the `extends` chain is
// not searched. Returns the member's value, or NULL when there is no such
// member (or the member's value itself is NULL).
// `object` is assumed to be an ObjectValue; this is not checked.
Value * find_current_object_field ( Value * object, Str name ) {
  ObjectValue * self = (ObjectValue*) object;
  SimpleEntry * entry = self -> members;
  SimpleEntry * stop = entry + self -> member_cnt;
  for ( ; entry != stop; ++entry )
    if ( str_eq ( entry -> name, name ) )
      return entry -> value;
  return NULL;
}

// Allocate an integer value holding `val` on `heap`.
// The extra 8 bytes presumably guarantee room for the HeapPointerValue
// forwarding record that in_to_semispace writes over a moved object; gc_collect
// steps over integers by the same sizeof (IntValue) + 8.
Value * make_int ( Heap * heap, i32 val ) {
  IntValue * cell = heap_alloc ( heap, sizeof (IntValue) + 8 );
  const IntValue init = { .kind = { .kind = VALUE_INTEGER }, .integer = val };
  *cell = init;
  return &cell -> kind;
}

// Allocate a boolean value holding `val` on `heap`.
// The "+ 8" matches the size gc_collect uses to step over booleans. It
// presumably leaves room for a forwarding record.
Value * make_bool ( Heap * heap, bool val ) {
  BoolValue * cell = heap_alloc ( heap, sizeof (BoolValue) + 8 );
  const BoolValue init = { .kind = { .kind = VALUE_BOOLEAN }, .boolean = val };
  *cell = init;
  return &cell -> kind;
}

// Allocate a function value wrapping the AST function `function` on `heap`.
// NOTE(review): unlike make_cfunction, no extra 8 bytes are reserved here.
// Both use the VALUE_FUNCTION tag, and gc_collect/copy_value treat every
// VALUE_FUNCTION as a CFunctionValue of size sizeof (CFunctionValue) + 8.
// Presumably these values never live in a collected heap; confirm.
Value * make_function ( Heap * heap, AstFunction * function ) {
  FunctionValue * cell = heap_alloc ( heap, sizeof (FunctionValue) );
  const FunctionValue init = { .kind = { .kind = VALUE_FUNCTION }, .function = function };
  *cell = init;
  return &cell -> kind;
}

// Allocate a native (constant) function value wrapping `function` on `heap`.
// Its size, sizeof (CFunctionValue) + 8, is the size gc_collect uses to step
// over VALUE_FUNCTION objects.
Value * make_cfunction ( Heap * heap, ConstantFunction function ) {
  CFunctionValue * cell = heap_alloc ( heap, sizeof (CFunctionValue) + 8 );
  const CFunctionValue init = { .kind = { .kind = VALUE_FUNCTION }, .function = function };
  *cell = init;
  return &cell -> kind;
}

// Allocate an array value with `len` element slots on `heap`.
// Every slot starts as NULL, so the array never holds indeterminate pointers,
// even if the caller has not filled it yet.
// A `len` whose byte size would overflow size_t (e.g. a negative count
// converted to size_t) is reported as out of memory (exit status 22).
// Without this check, the wrapped, too-small allocation would be overrun.
Value * make_array ( Heap * heap, size_t len ) {
  if ( len > ( SIZE_MAX - sizeof (ArrayValue) ) / sizeof (Value*) ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 22 );
  }
  ArrayValue * value = heap_alloc ( heap, sizeof (ArrayValue) + len * sizeof (Value*) );
  *value = (ArrayValue) { .kind = (Value) { .kind = VALUE_ARRAY }, .length = len };
  for ( size_t i = 0; i < len; ++i )
    value -> elements [ i ] = NULL;
  return (Value*) value;
}

// Allocate an object value with `member_cnt` member slots on `heap`.
// `extends` starts as NULL (zeroed by the compound literal). Every member
// slot starts zeroed with a NULL value; gc_collect reads member values and
// skips NULL ones, so they must never be indeterminate.
// A `member_cnt` whose byte size would overflow size_t is reported as out of
// memory (exit status 22).
Value * make_object ( Heap * heap, size_t member_cnt ) {
  if ( member_cnt > ( SIZE_MAX - sizeof (ObjectValue) ) / sizeof (SimpleEntry) ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 22 );
  }
  ObjectValue * value = heap_alloc ( heap, sizeof (ObjectValue) + member_cnt * sizeof (SimpleEntry) );
  *value = (ObjectValue) { .kind = (Value) { .kind = VALUE_OBJECT }, .member_cnt = member_cnt };
  for ( size_t i = 0; i < member_cnt; ++i )
    value -> members [ i ] = (SimpleEntry) { .value = NULL };
  return (Value*) value;
}

// Allocate a null value on `heap`.
// The extra 8 bytes match the size gc_collect uses to step over nulls. They
// presumably make room for the forwarding record in_to_semispace writes.
Value * make_null ( Heap * heap ) {
  NullValue * cell = heap_alloc ( heap, sizeof (NullValue) + 8 );
  const NullValue init = { .kind = { .kind = VALUE_NULL } };
  *cell = init;
  return &cell -> kind;
}