#include <stdlib.h>
#include <stdio.h>
#include <inttypes.h>

#include "parser.h"
#include "ast_interpreter.h"

/*
 * Truthiness of a runtime value: null and the boolean false are falsy,
 * a boolean true is truthy, and every other value (integers, arrays,
 * objects, functions, ...) is truthy regardless of its contents.
 */
bool value_to_bool ( Value value ) {
  switch ( value . kind ) {
    case VALUE_NULL:
      return false;
    case VALUE_BOOLEAN:
      return value . boolean;
    default:
      return true;
  }
}

/*
 * Open a new, empty local scope whose parent is the current scope and make it
 * the current scope. Terminates the process if the scope cannot be allocated.
 */
void env_push ( ASTInterpreterState * state ) {
  Environment * env = malloc ( sizeof *env );
  if ( ! env ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 1 );
  }
  *env = (Environment) { .prev = state -> current_env, .start = NULL };
  state -> current_env = env;
}

/*
 * Close the current local scope: free every entry it holds and the scope
 * itself, then make its parent the current scope. The caller must ensure a
 * local scope is open (current_env is dereferenced unconditionally).
 */
void env_pop ( ASTInterpreterState * state ) {
  Environment * scope = state -> current_env;
  EnvironmentEntry * cursor = scope -> start;
  while ( cursor ) {
    EnvironmentEntry * doomed = cursor;
    cursor = cursor -> next;
    free ( doomed );
  }
  state -> current_env = scope -> prev;
  free ( scope );
}

/*
 * Assign `value` to the variable `name`.
 *
 * Lookup order is the current scope chain (innermost first), then the global
 * scope. If the name is bound nowhere, it is implicitly defined in the
 * innermost scope (the global scope when no local scope is open) instead of
 * being reported as an error. Terminates the process if memory runs out.
 */
void env_put ( ASTInterpreterState * state, Str name, Value value ) {
  for ( Environment * env = state -> current_env; env; env = env -> prev ) {
    for ( EnvironmentEntry * entry = env -> start; entry; entry = entry -> next ) {
      if ( str_eq ( entry -> name, name ) ) {
        entry -> value = value;
        return;
      }
    }
  }
  for ( EnvironmentEntry * entry = state -> global_env . start; entry; entry = entry -> next ) {
    if ( str_eq ( entry -> name, name ) ) {
      entry -> value = value;
      return;
    }
  }

  // Unknown name: define it (prepended) in the innermost scope.
  Environment * target = state -> current_env ? state -> current_env : & state -> global_env;
  EnvironmentEntry * entry = malloc ( sizeof *entry );
  if ( ! entry ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 1 );
  }
  *entry = (EnvironmentEntry) { .name = name, .value = value, .next = target -> start };
  target -> start = entry;
}

/*
 * Define `name` with `value` in the innermost scope (the global scope when no
 * local scope is open).
 *
 * If that same scope already binds `name`, its value is overwritten rather
 * than reported as an error. Bindings in outer scopes are shadowed, not
 * modified. New entries are appended at the end of the scope's list.
 * Terminates the process if memory runs out.
 */
void env_def ( ASTInterpreterState * state, Str name, Value value ) {
  Environment * env = state -> current_env ? state -> current_env : & state -> global_env;
  // `link` points at the pointer to fill in: start, or the last entry's next.
  EnvironmentEntry ** link = & env -> start;
  for ( ; *link; link = & ( *link ) -> next ) {
    if ( str_eq ( name, ( *link ) -> name ) ) {
      ( *link ) -> value = value;
      return;
    }
  }
  EnvironmentEntry * entry = malloc ( sizeof *entry );
  if ( ! entry ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 1 );
  }
  *entry = (EnvironmentEntry) { .name = name, .value = value, .next = NULL };
  *link = entry;
}

/*
 * Look up the variable `name` and return its value.
 *
 * Search order is the current scope chain (innermost first), then the global
 * scope. Reading an unknown name is not an error: the name is implicitly
 * defined as null in the innermost scope (the global scope when no local
 * scope is open), and null is returned. Terminates the process if memory
 * runs out.
 */
Value env_get ( ASTInterpreterState * state, Str name ) {
  for ( Environment * env = state -> current_env; env; env = env -> prev ) {
    for ( EnvironmentEntry * entry = env -> start; entry; entry = entry -> next ) {
      if ( str_eq ( entry -> name, name ) )
        return entry -> value;
    }
  }
  for ( EnvironmentEntry * entry = state -> global_env . start; entry; entry = entry -> next ) {
    if ( str_eq ( entry -> name, name ) )
      return entry -> value;
  }

  Environment * target = state -> current_env ? state -> current_env : & state -> global_env;
  EnvironmentEntry * entry = malloc ( sizeof *entry );
  if ( ! entry ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 1 );
  }
  *entry = (EnvironmentEntry) { .name = name, .value = (Value) { .kind = VALUE_NULL }, .next = target -> start };
  target -> start = entry;
  return entry -> value;
}

/*
 * Bump-allocate `len` bytes aligned to `align` from `heap`.
 *
 * Returns NULL when the space left after alignment padding cannot hold `len`
 * bytes, or when the heap has no backing storage (e.g. after heap_destroy).
 * An `align` of 0 is treated as 1.
 */
void * heap_alloc ( Heap * heap, size_t len, size_t align ) {
  if ( ! heap -> next )
    return NULL;
  size_t avail = (size_t) ( heap -> end - heap -> next );
  size_t rem = align > 1 ? (size_t) ( (uintptr_t) heap -> next % align ) : 0;
  size_t pad = rem ? align - rem : 0;
  // Compare sizes instead of forming `next + pad + len` first: that pointer
  // could lie beyond `end`, and creating it would itself be undefined.
  if ( pad > avail || len > avail - pad )
    return NULL;
  void * ret = heap -> next + pad;
  heap -> next += pad + len;
  return ret;
}

/*
 * Give `heap` a freshly allocated backing buffer of `heap_size` bytes and
 * reset its allocation cursor to the start. Terminates the process if the
 * buffer cannot be allocated. Release the buffer with heap_destroy.
 */
void heap_init ( Heap * heap, size_t heap_size ) {
  heap -> begin = malloc ( heap_size );
  if ( ! heap -> begin ) {
    fprintf ( stderr, "Unable to allocate a heap of %zu bytes.\n", heap_size );
    exit ( 1 );
  }
  heap -> next = heap -> begin;
  heap -> end = heap -> begin + heap_size;
}

/*
 * Free the heap's backing buffer. The pointers are reset to NULL so a later
 * heap_destroy is a harmless free(NULL) and nothing keeps a dangling pointer
 * into the released buffer.
 */
void heap_destroy ( Heap * heap ) {
  free ( heap -> begin );
  heap -> begin = NULL;
  heap -> next = NULL;
  heap -> end = NULL;
}

/*
 * Prepare `state` for evaluation: no local scope is open, the global scope is
 * empty and parentless, and allocations go to `heap`.
 */
void state_init ( ASTInterpreterState * state, Heap * heap ) {
  state -> current_env = NULL;
  state -> global_env . start = NULL;
  state -> global_env . prev = NULL;
  state -> heap = heap;
}

/*
 * Release everything `state` owns: the heap's backing buffer, every global
 * variable entry, and any local scopes still open (entries and scope nodes).
 * The Heap structure itself belongs to the caller and is not freed.
 */
void state_destroy ( ASTInterpreterState * state ) {
  heap_destroy ( state -> heap );

  EnvironmentEntry * entry = state -> global_env . start;
  while ( entry ) {
    EnvironmentEntry * next = entry -> next;
    free ( entry );
    entry = next;
  }
  state -> global_env . start = NULL;

  // env_pop frees both a scope's entries and the Environment node itself.
  // The previous version freed only the entries and leaked every node.
  while ( state -> current_env )
    env_pop ( state );
}

/*
 * Follow the `extends` chain starting at `object` and return a pointer to the
 * first value in it that is not an object (the primitive base). When `object`
 * itself is not an object, it is returned unchanged.
 */
Value * get_base ( Value * object ) {
  while ( object -> kind == VALUE_OBJECT )
    object = & ( (Object *) object -> address ) -> extends;
  return object;
}

Value try_operator ( Value object, Value * arguments, size_t argc, Str * name ) {
  switch ( object . kind ) {
    case VALUE_INTEGER: {
      if ( argc != 1 ) {
        fprintf ( stderr, "Invalid argument count for integer operation %.*s.\n", (int) name -> len, name -> str );
        exit ( 10 );
      }
      bool is_int = arguments [ 0 ] . kind == VALUE_INTEGER;
      if ( is_int ) {
        if ( str_eq ( *name, STR ("+") ) )
          return (Value) { .kind = VALUE_INTEGER, .integer = object . integer + arguments [ 0 ] . integer };
        if ( str_eq ( *name, STR ("-") ) )
          return (Value) { .kind = VALUE_INTEGER, .integer = object . integer - arguments [ 0 ] . integer };
        if ( str_eq ( *name, STR ("*") ) )
          return (Value) { .kind = VALUE_INTEGER, .integer = object . integer * arguments [ 0 ] . integer };
        if ( str_eq ( *name, STR ("/") ) )
          return (Value) { .kind = VALUE_INTEGER, .integer = object . integer / arguments [ 0 ] . integer };
        if ( str_eq ( *name, STR ("%") ) )
          return (Value) { .kind = VALUE_INTEGER, .integer = object . integer % arguments [ 0 ] . integer };
      }
      if ( str_eq ( *name, STR ("<=") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = is_int ? object . integer <= arguments [ 0 ] . integer : false };
      if ( str_eq ( *name, STR (">=") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = is_int ? object . integer >= arguments [ 0 ] . integer : false };
      if ( str_eq ( *name, STR ("<") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = is_int ? object . integer < arguments [ 0 ] . integer : false };
      if ( str_eq ( *name, STR (">") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = is_int ? object . integer > arguments [ 0 ] . integer : false };
      if ( str_eq ( *name, STR ("==") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = is_int ? object . integer == arguments [ 0 ] . integer : false };
      if ( str_eq ( *name, STR ("!=") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = is_int ? object . integer != arguments [ 0 ] . integer : false };
      break;
    }
    case VALUE_BOOLEAN: {
      if ( argc != 1 ) {
        fprintf ( stderr, "Invalid argument count for bool operation %.*s.\n", (int) name -> len, name -> str );
        exit ( 10 );
      }
      bool is_bool = arguments [ 0 ] . kind == VALUE_BOOLEAN;
      if ( str_eq ( *name, STR ("&") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = is_bool ? object . boolean &  arguments [ 0 ] . boolean : false };
      if ( str_eq ( *name, STR ("|") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = is_bool ? object . boolean |  arguments [ 0 ] . boolean : false };
      if ( str_eq ( *name, STR ("==") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = is_bool ? object . boolean == arguments [ 0 ] . boolean : false };
      if ( str_eq ( *name, STR ("!=") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = is_bool ? object . boolean != arguments [ 0 ] . boolean : false };
      break;
    }
    case VALUE_NULL: {
      if ( argc != 1 ) {
        fprintf ( stderr, "Invalid amount of arguments for null operation %.*s.\n", (int) name -> len, name -> str );
        exit ( 10 );
      }
      if ( str_eq ( *name, STR ("==") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = arguments [ 0 ] . kind == VALUE_NULL };
      if ( str_eq ( *name, STR ("!=") ) )
        return (Value) { .kind = VALUE_BOOLEAN, .boolean = arguments [ 0 ] . kind != VALUE_NULL };
      break;
    }
    case VALUE_ARRAY:
      if ( str_eq ( *name, STR ("get") ) ) {
        if ( argc != 1 || arguments [ 0 ] . kind != VALUE_INTEGER ) {
          fprintf ( stderr, "Invalid argument for array get.\n" );
          exit ( 11 );
        }
        Value * data = (Value*) object . address;
        i32 index = arguments [ 0 ] . integer;
        if ( index < 0 || index >= data [ 0 ] . integer ) {
          fprintf ( stderr, "Index is out of bounds.\n" );
          exit ( 11 );
        }
        return data [ index + 1 ];
      }
      if ( str_eq ( *name, STR ("set") ) ) {
        if ( argc != 2 || arguments [ 0 ] . kind != VALUE_INTEGER ) {
          fprintf ( stderr, "Invalid arguments for array set.\n" );
          exit ( 12 );
        }
        Value * data = (Value*) object . address;
        i32 index = arguments [ 0 ] . integer;
        if ( index < 0 || index >= data [ 0 ] . integer ) {
          fprintf ( stderr, "Index is out of bounds.\n" );
          exit ( 12 );
        }
        data [ index + 1 ] = arguments [ 1 ];
        return data [ index + 1 ];
      }
      break;
    default:
      break;
  }
  //fprintf ( stderr, "Method %.*s does not exist.\n", (int) name -> len, name -> str );
  //exit ( 13 );
  return (Value) { .kind = VALUE_INVALID };
}

/*
 * Call a function or method and return its result.
 *
 * All `argc` arguments are evaluated left to right in the caller's scope first.
 * - is_function: `callee` must be a function value, and `this` is bound to null.
 * - otherwise `name` is a method looked up on `callee`. The receiver's members
 *   are searched first, then each object along its `extends` chain. The first
 *   (most derived) match wins, and `this` is bound to the object in which it
 *   was found. If no member matches, the built-in operators of the chain's
 *   primitive base (see try_operator) are tried. If none applies, the process
 *   exits with code 13.
 * A callee that is not a function makes the process exit with code 4.
 *
 * The body runs in a fresh scope with the caller's scope chain detached, so
 * only `this`, the parameters and the global scope are visible inside it.
 */
Value function_call ( ASTInterpreterState * state, Value callee, bool is_function, Ast ** arguments, size_t argc, Str * name ) {
  Value * values = malloc ( argc * sizeof *values );
  if ( argc && ! values ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 1 );
  }
  for ( size_t i = 0; i < argc; ++i )
    values [ i ] = evaluate ( state, arguments [ i ] );

  Value function = callee;
  if ( ! is_function ) {
    function = (Value) { .kind = VALUE_NULL };
    // Stop at the first object that defines `name`. Previously the inner
    // `break` left only the member loop, so a same-named member of an
    // ancestor overrode the derived object's one and rebound `this`.
    bool found = false;
    Value * curr = &callee;
    while ( ! found && curr -> kind == VALUE_OBJECT ) {
      Object * obj = (Object*) curr -> address;
      for ( size_t i = 0; i < obj -> member_cnt; ++i ) {
        if ( str_eq ( obj -> members [ i ] . name, *name ) ) {
          function = obj -> members [ i ] . value;
          callee = *curr;
          found = true;
          break;
        }
      }
      curr = & obj -> extends;
    }
    if ( function . kind == VALUE_NULL ) {
      Value ret = try_operator ( *get_base ( &callee ), values, argc, name );
      free ( values );
      if ( ret . kind != VALUE_INVALID )
        return ret;
      fprintf ( stderr, "Method %.*s does not exist for this object and/or arguments.\n", (int) name -> len, name -> str );
      exit ( 13 );
    }
  }
  if ( function . kind != VALUE_FUNCTION ) {
    fprintf ( stderr, "Invalid callee.\n" );
    exit ( 4 );
  }

  Environment * saved_env = state -> current_env;
  state -> current_env = NULL;
  env_push ( state );
  env_def ( state, STR ("this"), is_function ? (Value) { .kind = VALUE_NULL } : callee );
  // NOTE(review): argc is not checked against the function's parameter count;
  // passing more arguments than parameters reads past `parameters`. Confirm
  // the count field of AstFunction and add an arity check.
  for ( size_t i = 0; i < argc; ++i )
    env_def ( state, function . function -> parameters [ i ], values [ i ] );
  Value ret = evaluate ( state, function . function -> body );
  env_pop ( state );
  state -> current_env = saved_env;

  free ( values );
  return ret;
}

/*
 * Find the member `name` in the object `object` or, failing that, in the
 * objects of its `extends` chain (nearest first). Returns a pointer to the
 * member's value slot, or NULL when no object in the chain defines it.
 * `object` itself must be an object value; it is dereferenced unconditionally.
 */
Value * get_object_field ( Value * object, Str name ) {
  Value * link = object;
  for ( ;; ) {
    Object * obj = (Object*) link -> address;
    for ( size_t idx = 0; idx < obj -> member_cnt; ++idx ) {
      SimpleEntry * member = & obj -> members [ idx ];
      if ( str_eq ( member -> name, name ) )
        return & member -> value;
    }
    link = & obj -> extends;
    if ( link -> kind != VALUE_OBJECT )
      return NULL;
  }
}

int compare_entry ( const void * a, const void * b ) {
  return str_cmp ( ((SimpleEntry*) a) -> name, ((SimpleEntry*) b) -> name );
}

/*
 * Print `value` to stdout.
 *   integer -> decimal; boolean -> true/false; null -> null;
 *   function -> the word "function";
 *   array   -> [e1, e2, ...];
 *   object  -> object(..=<parent>, name=value, ...). The parent part appears
 *              only when the object extends something other than null, and
 *              members are listed sorted by name (compare_entry).
 * An invalid value prints a diagnostic to stderr and nothing to stdout.
 * NOTE(review): self-referencing arrays or objects recurse without bound.
 */
void print_value ( Value value ) {
  switch ( value . kind ) {
    case VALUE_INTEGER:
      printf ( "%" PRIi32, value . integer );
      break;
    case VALUE_BOOLEAN:
      if ( value . boolean )
        printf ( "true" );
      else
        printf ( "false" );
      break;
    case VALUE_NULL:
      printf ( "null" );
      break;
    case VALUE_FUNCTION:
      printf ( "function" );
      break;
    case VALUE_ARRAY: {
      // Slot 0 holds the length; elements live in slots 1..length.
      Value * addr = (Value*) value . address;
      putchar ( '[' );
      for ( i32 i = 1; i <= addr [ 0 ] . integer; ++i ) {
        if ( i != 1 )
          printf ( ", " );
        print_value ( addr [ i ] );
      }
      putchar ( ']' );
      break;
    }
    case VALUE_OBJECT: {
      Object * obj = (Object*) value . address;
      printf ( "object(" );
      bool parent = false;
      if ( obj -> extends . kind != VALUE_NULL ) {
        parent = true;
        printf ( "..=" );
        print_value ( obj -> extends );
      }
      // Sort a private copy. Sorting obj->members in place made printing
      // reorder the object's own member table as a side effect.
      SimpleEntry * sorted = NULL;
      if ( obj -> member_cnt ) {
        sorted = malloc ( obj -> member_cnt * sizeof *sorted );
        if ( ! sorted ) {
          fprintf ( stderr, "Out of memory.\n" );
          exit ( 1 );
        }
        for ( size_t i = 0; i < obj -> member_cnt; ++i )
          sorted [ i ] = obj -> members [ i ];
        qsort ( sorted, obj -> member_cnt, sizeof *sorted, compare_entry );
      }
      for ( size_t i = 0; i < obj -> member_cnt; ++i ) {
        if ( i != 0 || parent )
          printf ( ", " );
        printf ( "%.*s=", (int) sorted [ i ] . name . len, sorted [ i ] . name . str );
        print_value ( sorted [ i ] . value );
      }
      free ( sorted );
      putchar ( ')' );
      break;
    }
    case VALUE_INVALID:
      fprintf ( stderr, "Invalid value.\n" );
      break;
  }
}

/*
 * Print `format` to stdout, replacing each `~` with the next value of `args`
 * (printed via print_value).
 * Supported escapes: \~ \n \" \r \t \\. Any other escape, a backslash at the
 * very end of the format, or more `~` placeholders than arguments terminate
 * the process with exit code 1.
 * NOTE(review): surplus arguments (more args than placeholders) are silently
 * ignored.
 */
void fml_print ( Str format, Value * args, size_t argc ) {
  size_t arg = 0;
  for ( size_t i = 0; i < format . len; ++i ) {
    u8 c = format . str [ i ];
    if ( c == '\\' ) {
      // Guard the read of the escaped character: a trailing backslash used
      // to read format.str[format.len], one past the end of the string.
      if ( i + 1 >= format . len ) {
        fprintf ( stderr, "Unterminated escape sequence at end of format string.\n" );
        exit ( 1 );
      }
      c = format . str [ ++i ];
      switch ( c ) {
        case '~':
          putchar ( '~' );
          break;
        case 'n':
          putchar ( '\n' );
          break;
        case '"':
          putchar ( '"' );
          break;
        case 'r':
          putchar ( '\r' );
          break;
        case 't':
          putchar ( '\t' );
          break;
        case '\\':
          putchar ( '\\' );
          break;
        default:
          fprintf ( stderr, "Invalid escaped character %c.\n", c );
          exit ( 1 );
      }
    } else if ( c == '~' ) {
      if ( arg == argc ) {
        fprintf ( stderr, "Too many placeholders, not enough arguments.\n" );
        exit ( 1 );
      }
      print_value ( args [ arg++ ] );
    } else
      putchar ( c );
  }
}

/* Report that the interpreter heap ran out of space and terminate. */
static _Noreturn void evaluate_heap_exhausted ( void ) {
  fprintf ( stderr, "Heap exhausted.\n" );
  exit ( 1 );
}

/*
 * Evaluate `ast` in `state` and return its value.
 *
 * Sequences (top level, blocks) yield their last expression's value, or null
 * when empty. Blocks, conditional branches, loop bodies, array initializers
 * and object member initializers each run in their own freshly pushed scope.
 * Arrays and objects are allocated from state->heap. Exhausting it, or any
 * other runtime error detected here, prints a message to stderr and
 * terminates the process.
 */
Value evaluate ( ASTInterpreterState * state, Ast * ast ) {
  switch ( ast -> kind ) {
    case AST_INTEGER:
      return (Value) { .kind = VALUE_INTEGER, .integer = ((AstInteger*) ast ) -> value };
    case AST_BOOLEAN:
      return (Value) { .kind = VALUE_BOOLEAN, .boolean = ((AstBoolean*) ast ) -> value };
    case AST_NULL:
      return (Value) { .kind = VALUE_NULL };
    case AST_PRINT: {
      AstPrint * printAst = (AstPrint*) ast;
      Value * args = malloc ( sizeof (Value) * printAst -> argument_cnt );
      if ( printAst -> argument_cnt && ! args ) {
        fprintf ( stderr, "Out of memory.\n" );
        exit ( 1 );
      }
      for ( size_t i = 0; i < printAst -> argument_cnt; ++i )
        args [ i ] = evaluate ( state, printAst -> arguments [ i ] );
      fml_print ( printAst -> format, args, printAst -> argument_cnt );
      free ( args );
      return (Value) { .kind = VALUE_NULL };
    }
    case AST_TOP: {
      AstTop * topAst = (AstTop*) ast;
      // Start from null so an empty program does not read expressions[0].
      Value val = { .kind = VALUE_NULL };
      for ( size_t i = 0; i < topAst -> expression_cnt; ++i )
        val = evaluate ( state, topAst -> expressions [ i ] );
      return val;
    }
    case AST_DEFINITION: {
      AstDefinition * defAst = (AstDefinition*) ast;
      Value value = evaluate ( state, defAst -> value );
      env_def ( state, defAst -> name, value );
      return value;
    }
    case AST_VARIABLE_ACCESS:
      return env_get ( state, ((AstVariableAccess*) ast ) -> name );
    case AST_VARIABLE_ASSIGNMENT: {
      AstVariableAssignment * assignAst = (AstVariableAssignment*) ast;
      Value value = evaluate ( state, assignAst -> value );
      env_put ( state, assignAst -> name, value );
      return value;
    }
    case AST_BLOCK: {
      AstBlock * blockAst = (AstBlock*) ast;
      env_push ( state );
      // Start from null so an empty block does not read expressions[0].
      Value value = { .kind = VALUE_NULL };
      for ( size_t i = 0; i < blockAst -> expression_cnt; ++i )
        value = evaluate ( state, blockAst -> expressions [ i ] );
      env_pop ( state );
      return value;
    }
    case AST_CONDITIONAL: {
      AstConditional * condAst = (AstConditional*) ast;
      Value cond = evaluate ( state, condAst -> condition );
      env_push ( state );
      Ast * cont = value_to_bool ( cond ) ? condAst -> consequent : condAst -> alternative;
      Value value = evaluate ( state, cont );
      env_pop ( state );
      return value;
    }
    case AST_LOOP: {
      AstLoop * loopAst = (AstLoop*) ast;
      while ( value_to_bool ( evaluate ( state, loopAst -> condition ) ) ) {
        env_push ( state );
        evaluate ( state, loopAst -> body );
        env_pop ( state );
      }
      return (Value) { .kind = VALUE_NULL };
    }
    case AST_FUNCTION:
      return (Value) { .kind = VALUE_FUNCTION, .function = (AstFunction*) ast };
    case AST_FUNCTION_CALL: {
      AstFunctionCall * callAst = (AstFunctionCall*) ast;
      Value function = evaluate ( state, callAst -> function );
      return function_call ( state, function, true, callAst -> arguments, callAst -> argument_cnt, NULL );
    }
    case AST_METHOD_CALL: {
      AstMethodCall * methodAst = (AstMethodCall*) ast;
      Value object = evaluate ( state, methodAst -> object );
      if ( object . kind == VALUE_FUNCTION ) {
        fprintf ( stderr, "Calling method on function.\n" );
        exit ( 4 );
      }
      return function_call ( state, object, false, methodAst -> arguments, methodAst -> argument_cnt, &methodAst -> name );
    }
    case AST_ARRAY: {
      AstArray * arrAst = (AstArray*) ast;
      Value size = evaluate ( state, arrAst -> size );
      // The upper bound keeps (size + 1) * sizeof (Value) from overflowing
      // size_t. The old code computed size + 1 in i32 (UB at INT32_MAX).
      if ( size . kind != VALUE_INTEGER || size . integer < 0
           || (size_t) size . integer >= SIZE_MAX / sizeof (Value) ) {
        fprintf ( stderr, "Invalid array size.\n" );
        exit ( 5 );
      }
      // Slot 0 stores the length; elements occupy slots 1..length.
      size_t slots = (size_t) size . integer + 1;
      Value * addr = (Value*) heap_alloc ( state -> heap, sizeof (Value) * slots, alignof (Value) );
      if ( ! addr )
        evaluate_heap_exhausted ();
      addr [ 0 ] = (Value) { .kind = VALUE_INTEGER, .integer = size . integer };
      for ( i32 i = 1; i <= size . integer; ++i ) {
        env_push ( state );
        Value val = evaluate ( state, arrAst -> initializer );
        env_pop ( state );
        addr [ i ] = val;
      }
      return (Value) { .kind = VALUE_ARRAY, .address = addr };
    }
    case AST_INDEX_ACCESS: {
      AstIndexAccess * iaccess = (AstIndexAccess*) ast;
      Value object = evaluate ( state, iaccess -> object );
      Str tmp = STR ("get");
      return function_call ( state, object, false, &iaccess -> index, 1, &tmp );
    }
    case AST_INDEX_ASSIGNMENT: {
      AstIndexAssignment * iassign = (AstIndexAssignment*) ast;
      Value object = evaluate ( state, iassign -> object );
      Ast * arguments [2];
      arguments [ 0 ] = iassign -> index;
      arguments [ 1 ] = iassign -> value;
      Str tmp = STR ("set");
      return function_call ( state, object, false, arguments, 2, &tmp );
    }
    case AST_OBJECT: {
      AstObject * objAst = (AstObject*) ast;
      Value extends = evaluate ( state, objAst -> extends );
      Object * addr = (Object*) heap_alloc ( state -> heap, sizeof (Object), alignof (Object) );
      if ( ! addr )
        evaluate_heap_exhausted ();
      addr -> member_cnt = objAst -> member_cnt;
      addr -> extends = extends;
      addr -> members = (SimpleEntry*) heap_alloc ( state -> heap, sizeof (SimpleEntry) * addr -> member_cnt, alignof (SimpleEntry) );
      // A member table is only needed when there are members. NULL from a
      // zero-byte request is harmless because it is never dereferenced.
      if ( addr -> member_cnt && ! addr -> members )
        evaluate_heap_exhausted ();
      for ( size_t i = 0; i < addr -> member_cnt; ++i ) {
        if ( objAst -> members [ i ] -> kind != AST_DEFINITION ) {
          fprintf ( stderr, "Invalid object member.\n" );
          exit ( 7 );
        }
        AstDefinition * defAst = (AstDefinition*) objAst -> members [ i ];
        // Each initializer runs in a throwaway scope. The definition's own
        // env_def lands there and is discarded; the value goes into the object.
        env_push ( state );
        Value value = evaluate ( state, objAst -> members [ i ] );
        env_pop ( state );
        addr -> members [ i ] . name  = defAst -> name;
        addr -> members [ i ] . value = value;
      }
      return (Value) { .kind = VALUE_OBJECT, .address = addr };
    }
    case AST_FIELD_ACCESS: {
      AstFieldAccess * accessAst = (AstFieldAccess *) ast;
      Value object_value = evaluate ( state, accessAst -> object );
      if ( object_value . kind != VALUE_OBJECT ) {
        fprintf ( stderr, "Trying to access field of non object.\n" );
        exit ( 8 );
      }
      Value * result = get_object_field ( &object_value, accessAst -> field );
      if ( ! result ) {
        fprintf ( stderr, "Field %.*s was not found in specified object.\n", (int) accessAst -> field . len, accessAst -> field . str );
        exit ( 8 );
      }
      return *result;
    }
    case AST_FIELD_ASSIGNMENT: {
      AstFieldAssignment * assignAst = (AstFieldAssignment*) ast;
      Value object_value = evaluate ( state, assignAst -> object );
      if ( object_value . kind != VALUE_OBJECT ) {
        fprintf ( stderr, "Trying to assign to a field of non object.\n" );
        exit ( 8 );
      }
      Value value = evaluate ( state, assignAst -> value );
      Value * result = get_object_field ( &object_value, assignAst -> field );
      if ( ! result ) {
        fprintf ( stderr, "Field %.*s was not found in specified object.\n", (int) assignAst -> field . len, assignAst -> field . str );
        exit ( 8 );
      }
      *result = value;
      return value;
    }
  }
  return (Value) { .kind = VALUE_NULL };
}