#include <stdlib.h>
#include <stdio.h>
#include <inttypes.h>

#include "parser.h"
#include "ast_interpreter.h"

/* Returns the truthiness of a runtime value: null and boolean false are
   falsy; every other value (including the integer 0, arrays, objects and
   functions) is truthy. */
bool value_to_bool ( Value value ) {
  switch ( value . kind ) {
    case VALUE_NULL:
      return false;
    case VALUE_BOOLEAN:
      return value . boolean;
    default:
      return true;
  }
}

/* Opens a new, empty local scope whose parent is the current scope and makes
   it current. The Environment is obtained with malloc and released by
   env_pop. Prints an error and exits with status 1 if memory is exhausted. */
void env_push ( ASTInterpreterState * state ) {
  Environment * env = malloc ( sizeof * env );
  if ( ! env ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 1 );
  }
  * env = (Environment) { .prev = state -> current_env, .start = NULL };
  state -> current_env = env;
}

/* Closes the innermost local scope. Every variable entry it owns is freed,
   then the Environment itself, and its parent becomes the current scope.
   Assumes state -> current_env is non-NULL, i.e. it is balanced by an
   earlier env_push. */
void env_pop ( ASTInterpreterState * state ) {
  Environment * top = state -> current_env;
  EnvironmentEntry * it = top -> start;
  while ( it != NULL ) {
    EnvironmentEntry * following = it -> next;
    free ( it );
    it = following;
  }
  state -> current_env = top -> prev;
  free ( top );
}

/* Overwrites the value of an already-defined variable `name`. The search runs
   from the innermost local scope outwards and then falls back to the global
   environment; the first match is updated. Assigning to a name that is not
   defined anywhere prints an error and exits with status 1. */
void env_put ( ASTInterpreterState * state, Str name, Value value ) {
  for ( Environment * scope = state -> current_env; scope != NULL; scope = scope -> prev ) {
    for ( EnvironmentEntry * e = scope -> start; e != NULL; e = e -> next ) {
      if ( str_eq ( e -> name, name ) ) {
        e -> value = value;
        return;
      }
    }
  }

  for ( EnvironmentEntry * e = state -> global_env . start; e != NULL; e = e -> next ) {
    if ( str_eq ( e -> name, name ) ) {
      e -> value = value;
      return;
    }
  }

  fprintf ( stderr, "Variable with name %.*s does not exist.\n", (int) name . len, name . str );
  exit ( 1 );
}

/* Defines a new variable `name` holding `value`. It goes into the innermost
   local scope, or into the global environment when no local scope is open.
   A name that already exists in that same scope is an error (exit status 1).
   Names in outer scopes may be shadowed. The new entry is appended at the end
   of the scope's list. Exits with status 1 if memory is exhausted. */
void env_def ( ASTInterpreterState * state, Str name, Value value ) {
  Environment * env = state -> current_env ? state -> current_env : & state -> global_env;
  EnvironmentEntry * last = NULL;
  for ( EnvironmentEntry * it = env -> start; it != NULL; it = it -> next ) {
    if ( str_eq ( name, it -> name ) ) {
      fprintf ( stderr, "Variable with name %.*s already exists.\n", (int) name . len, name . str );
      exit ( 1 );
    }
    last = it;
  }

  EnvironmentEntry * entry = malloc ( sizeof * entry );
  if ( ! entry ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 1 );
  }
  * entry = (EnvironmentEntry) { .name = name, .value = value, .next = NULL };
  if ( last )
    last -> next = entry;
  else
    env -> start = entry;
}

/* Returns the value of variable `name`. The search runs from the innermost
   local scope outwards and then falls back to the global environment.
   Reading an undefined name is NOT an error. The name is implicitly defined
   as null in the innermost local scope (or globally when no local scope is
   open), and null is returned. Exits with status 1 if memory for that
   implicit definition cannot be obtained. */
Value env_get ( ASTInterpreterState * state, Str name ) {
  for ( Environment * env = state -> current_env; env != NULL; env = env -> prev ) {
    for ( EnvironmentEntry * entry = env -> start; entry != NULL; entry = entry -> next ) {
      if ( str_eq ( entry -> name, name ) )
        return entry -> value;
    }
  }
  for ( EnvironmentEntry * entry = state -> global_env . start; entry != NULL; entry = entry -> next ) {
    if ( str_eq ( entry -> name, name ) )
      return entry -> value;
  }

  /* Not found anywhere: implicitly define it as null at the head of the
     innermost scope's list. */
  Environment * target = state -> current_env ? state -> current_env : & state -> global_env;
  EnvironmentEntry * created = malloc ( sizeof * created );
  if ( ! created ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 1 );
  }
  * created = (EnvironmentEntry) { .name = name, .value = (Value) { .kind = VALUE_NULL }, .next = target -> start };
  target -> start = created;
  return created -> value;
}

/* Bump-allocates `len` bytes from `heap`, aligned to `align` bytes. An
   `align` of 0 means no alignment requirement. Returns NULL, leaving the heap
   untouched, when the remaining space cannot hold the padding plus `len`.
   Individual allocations are never freed; the whole arena is released by
   heap_destroy. */
void * heap_alloc ( Heap * heap, size_t len, size_t align ) {
  /* Padding needed to bring `next` up to the requested alignment. */
  uintptr_t pos = (uintptr_t) heap -> next;
  size_t pad = align ? (size_t) ( ( align - pos % align ) % align ) : 0;

  /* Check the full request (padding + len) against the remaining space.
     Written as subtractions so nothing can wrap around. */
  size_t avail = (size_t) ( heap -> end - heap -> next );
  if ( pad > avail || len > avail - pad )
    return NULL;

  u8 * ret = heap -> next + pad;
  heap -> next = ret + len;
  return ret;
}

/* Reserves `heap_size` bytes with malloc as the backing store of the bump
   allocator and marks it empty (next == begin, end == begin + heap_size).
   Prints an error and exits with status 1 if the memory cannot be obtained. */
void heap_init ( Heap * heap, size_t heap_size ) {
  /* Request at least one byte: malloc(0) may legitimately return NULL, and
     pointer arithmetic on NULL is undefined behaviour. */
  heap -> begin = malloc ( heap_size ? heap_size : 1 );
  if ( ! heap -> begin ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 1 );
  }
  heap -> next = heap -> begin;
  heap -> end = heap -> begin + heap_size;
}

/* Frees the heap's backing store; every pointer handed out by heap_alloc
   becomes invalid. The begin/next/end pointers are reset to NULL so they no
   longer dangle, and a second heap_destroy reduces to a harmless free(NULL). */
void heap_destroy ( Heap * heap ) {
  free ( heap -> begin );
  heap -> begin = NULL;
  heap -> next = NULL;
  heap -> end = NULL;
}

/* Prepares a fresh interpreter state that allocates runtime arrays and
   objects from `heap`. The state starts with an empty global environment and
   no open local scope. */
void state_init ( ASTInterpreterState * state, Heap * heap ) {
  Environment * global = & state -> global_env;
  global -> start = NULL;
  global -> prev = NULL;
  state -> current_env = NULL;
  state -> heap = heap;
}

/* Tears down the interpreter state. It destroys the heap via heap_destroy,
   frees every global variable entry, and frees every local scope still on
   the current chain, both its entries and the malloc'd Environment from
   env_push. Afterwards the state has no globals and no open local scope. */
void state_destroy ( ASTInterpreterState * state ) {
  heap_destroy ( state -> heap );

  EnvironmentEntry * entry = state -> global_env . start;
  while ( entry ) {
    EnvironmentEntry * next = entry -> next;
    free ( entry );
    entry = next;
  }
  state -> global_env . start = NULL;

  Environment * env = state -> current_env;
  while ( env ) {
    Environment * prev = env -> prev;
    entry = env -> start;
    while ( entry ) {
      EnvironmentEntry * next = entry -> next;
      free ( entry );
      entry = next;
    }
    /* The scope record itself was allocated by env_push and must go too. */
    free ( env );
    env = prev;
  }
  state -> current_env = NULL;
}

/* Calls `function` with the argument expressions `arguments[0..argc)`.

   All arguments are evaluated first, left to right, in the caller's scope.

   When `is_function` is true:
   - The callee must be a function value; otherwise the call exits with
     status 4.
   - Its body runs in a fresh scope whose parent is NULL. Name lookups
     therefore see only that scope (`this`, bound to null, plus the
     parameters) and then the global environment.
   - The caller's scope chain is restored afterwards, and the body's value is
     returned.

   When `is_function` is false, the arguments are evaluated only for their
   side effects and null is returned.

   NOTE(review): argc is not checked against the callee's parameter count.
   Passing more arguments than parameters would index past `parameters` —
   confirm arity is validated elsewhere. */
Value function_call ( ASTInterpreterState * state, Value function, bool is_function, Ast ** arguments, size_t argc ) {
  /* calloc guards the argc * size product against overflow. A NULL result is
     only acceptable when argc == 0. */
  Value * values = calloc ( argc, sizeof * values );
  if ( argc && ! values ) {
    fprintf ( stderr, "Out of memory.\n" );
    exit ( 1 );
  }
  for ( size_t i = 0; i < argc; ++i )
    values [ i ] = evaluate ( state, arguments [ i ] );

  Value ret = (Value) { .kind = VALUE_NULL };
  if ( is_function ) {
    if ( function . kind != VALUE_FUNCTION ) {
      fprintf ( stderr, "Invalid calee.\n" );
      exit ( 4 );
    }
    /* Detach the caller's local scopes so the body cannot see them. */
    Environment * saved = state -> current_env;
    state -> current_env = NULL;
    env_push ( state );
    env_def ( state, STR ("this"), (Value) { .kind = VALUE_NULL } );
    for ( size_t i = 0; i < argc; ++i )
      env_def ( state, function . function -> parameters [ i ], values [ i ] );
    ret = evaluate ( state, function . function -> body );
    env_pop ( state );
    state -> current_env = saved;
  }
  free ( values );
  return ret;
}

/* Locates field `name` in the object held by `*object`. The search covers
   the object's own members, then follows its `extends` link for as long as
   the parent is itself an object. Returns a pointer to the stored member
   value, so callers can read or overwrite it in place. Exits with status 8
   when no object in the chain has the field.
   Assumes `*object` is a VALUE_OBJECT; the callers in evaluate check this. */
static Value * get_object_field ( Value * object, Str name ) {
  Value * current = object;
  for ( ;; ) {
    Object * obj = (Object*) current -> address;
    for ( size_t idx = 0; idx < obj -> member_cnt; ++idx ) {
      if ( str_eq ( obj -> members [ idx ] . name, name ) )
        return & obj -> members [ idx ] . value;
    }
    current = & obj -> extends;
    if ( current -> kind != VALUE_OBJECT )
      break;
  }
  fprintf ( stderr, "Field %.*s was not found in specified object.\n", (int) name . len, name . str );
  exit ( 8 );
}

/* Writes a textual rendering of `value` to stdout, with no trailing newline.

   Renderings by kind:
   - integer: decimal.
   - boolean: `true` / `false`.
   - null: `null`.
   - function: `function`.
   - array: `[a, b, ...]`. The heap block stores its length as an integer
     in slot 0, followed by the elements.
   - object: `object(..=parent, name=value, ...)`. The `..=parent` part is
     present only when `extends` is not null.

   Arrays and objects are printed recursively. */
void print_value ( Value value ) {
  switch ( value . kind ) {
    case VALUE_INTEGER:
      printf ( "%" PRIi32, value . integer );
      break;
    case VALUE_BOOLEAN:
      fputs ( value . boolean ? "true" : "false", stdout );
      break;
    case VALUE_NULL:
      fputs ( "null", stdout );
      break;
    case VALUE_FUNCTION:
      fputs ( "function", stdout );
      break;
    case VALUE_ARRAY: {
      Value * items = (Value*) value . address;
      i32 count = items [ 0 ] . integer;
      putchar ( '[' );
      for ( i32 k = 0; k < count; ++k ) {
        if ( k > 0 )
          fputs ( ", ", stdout );
        print_value ( items [ k + 1 ] );
      }
      putchar ( ']' );
      break;
    }
    case VALUE_OBJECT: {
      Object * obj = (Object*) value . address;
      bool need_separator = false;
      fputs ( "object(", stdout );
      if ( obj -> extends . kind != VALUE_NULL ) {
        fputs ( "..=", stdout );
        print_value ( obj -> extends );
        need_separator = true;
      }
      for ( size_t k = 0; k < obj -> member_cnt; ++k ) {
        SimpleEntry * member = & obj -> members [ k ];
        if ( need_separator )
          fputs ( ", ", stdout );
        need_separator = true;
        printf ( "%.*s=", (int) member -> name . len, member -> name . str );
        print_value ( member -> value );
      }
      putchar ( ')' );
      break;
    }
    default:
      break;
  }
}

/* Writes `format` to stdout.

   - Each `~` is replaced by the next value from `args[0..argc)`, rendered
     with print_value.
   - A backslash introduces one of the escapes \~ \n \" \r \t \\.
   - The following print an error to stderr and exit with status 1:
     any other escaped character, a backslash as the final character of the
     format, or more `~` placeholders than arguments.
   - Arguments left over after the last placeholder are silently ignored. */
void fml_print ( Str format, Value * args, size_t argc ) {
  size_t arg = 0;
  for ( size_t i = 0; i < format . len; ++i ) {
    u8 c = format . str [ i ];
    if ( c == '\\' ) {
      /* A trailing backslash has nothing to escape. Reading the next
         character would access format.str[format.len], out of bounds. */
      if ( i + 1 == format . len ) {
        fprintf ( stderr, "Unterminated escape sequence at end of format string.\n" );
        exit ( 1 );
      }
      c = format . str [ ++i ];
      switch ( c ) {
        case '~':
          putchar ( '~' );
          break;
        case 'n':
          putchar ( '\n' );
          break;
        case '"':
          putchar ( '"' );
          break;
        case 'r':
          putchar ( '\r' );
          break;
        case 't':
          putchar ( '\t' );
          break;
        case '\\':
          putchar ( '\\' );
          break;
        default:
          fprintf ( stderr, "Invalid escaped character %c.\n", c );
          exit ( 1 );
      }
    } else if ( c == '~' ) {
      if ( arg == argc ) {
        fprintf ( stderr, "Too many placeholders, not enough arguments.\n" );
        exit ( 1 );
      }
      print_value ( args [ arg++ ] );
    } else {
      putchar ( c );
    }
  }
}

/* Allocates `len` bytes aligned to `align` from the interpreter heap. Prints
   an error and exits with status 1 when a non-empty request cannot be
   satisfied. A zero-byte request may yield NULL; the caller never
   dereferences it. */
static void * eval_heap_alloc ( ASTInterpreterState * state, size_t len, size_t align ) {
  void * ptr = heap_alloc ( state -> heap, len, align );
  if ( ! ptr && len ) {
    fprintf ( stderr, "Out of heap memory.\n" );
    exit ( 1 );
  }
  return ptr;
}

/* Evaluates `count` expressions in order in the current scope. Returns the
   value of the last one, or null when `count` is 0. */
static Value eval_sequence ( ASTInterpreterState * state, Ast ** expressions, size_t count ) {
  Value value = (Value) { .kind = VALUE_NULL };
  for ( size_t i = 0; i < count; ++i )
    value = evaluate ( state, expressions [ i ] );
  return value;
}

/* Evaluates one AST node and returns its runtime value.

   Notable node semantics:
   - Blocks, conditional branches, loop bodies and each object member
     initializer run inside their own pushed scope.
   - A loop evaluates to null.
   - A definition or assignment evaluates to the assigned value.
   - Arrays and objects are allocated on state -> heap.
   - Node kinds not handled here evaluate to null.

   Runtime errors print to stderr and exit:
   - 5: invalid array size.
   - 7: object member that is not a definition.
   - 8: field access/assignment on a non-object.
   - 1: out of memory.
   Further exit codes come from the helpers this function calls. */
Value evaluate ( ASTInterpreterState * state, Ast * ast ) {
  switch ( ast -> kind ) {
    case AST_INTEGER:
      return (Value) { .kind = VALUE_INTEGER, .integer = ((AstInteger*) ast ) -> value };
    case AST_BOOLEAN:
      return (Value) { .kind = VALUE_BOOLEAN, .boolean = ((AstBoolean*) ast ) -> value };
    case AST_NULL:
      return (Value) { .kind = VALUE_NULL };
    case AST_PRINT: {
      AstPrint * printAst = (AstPrint*) ast;
      /* A NULL result is only acceptable when there are no arguments. */
      Value * args = calloc ( printAst -> argument_cnt, sizeof * args );
      if ( printAst -> argument_cnt && ! args ) {
        fprintf ( stderr, "Out of memory.\n" );
        exit ( 1 );
      }
      for ( size_t i = 0; i < printAst -> argument_cnt; ++i )
        args [ i ] = evaluate ( state, printAst -> arguments [ i ] );
      fml_print ( printAst -> format, args, printAst -> argument_cnt );
      free ( args );
      return (Value) { .kind = VALUE_NULL };
    }
    case AST_TOP: {
      AstTop * topAst = (AstTop*) ast;
      /* An empty program evaluates to null instead of reading expressions[0]. */
      return eval_sequence ( state, topAst -> expressions, topAst -> expression_cnt );
    }
    case AST_DEFINITION: {
      AstDefinition * defAst = (AstDefinition*) ast;
      Value value = evaluate ( state, defAst -> value );
      env_def ( state, defAst -> name, value );
      return value;
    }
    case AST_VARIABLE_ACCESS:
      return env_get ( state, ((AstVariableAccess*) ast ) -> name );
    case AST_VARIABLE_ASSIGNMENT: {
      AstVariableAssignment * assignAst = (AstVariableAssignment*) ast;
      Value value = evaluate ( state, assignAst -> value );
      env_put ( state, assignAst -> name, value );
      return value;
    }
    case AST_BLOCK: {
      AstBlock * blockAst = (AstBlock*) ast;
      env_push ( state );
      /* An empty block evaluates to null instead of reading expressions[0]. */
      Value value = eval_sequence ( state, blockAst -> expressions, blockAst -> expression_cnt );
      env_pop ( state );
      return value;
    }
    case AST_CONDITIONAL: {
      AstConditional * condAst = (AstConditional*) ast;
      Value cond = evaluate ( state, condAst -> condition );
      env_push ( state );
      Ast * branch = value_to_bool ( cond ) ? condAst -> consequent : condAst -> alternative;
      /* NOTE(review): guards against a NULL branch in case the parser leaves
         `alternative` unset for an `if` without `else` — confirm. */
      Value value = branch ? evaluate ( state, branch ) : (Value) { .kind = VALUE_NULL };
      env_pop ( state );
      return value;
    }
    case AST_LOOP: {
      AstLoop * loopAst = (AstLoop*) ast;
      while ( value_to_bool ( evaluate ( state, loopAst -> condition ) ) ) {
        env_push ( state );
        evaluate ( state, loopAst -> body );
        env_pop ( state );
      }
      return (Value) { .kind = VALUE_NULL };
    }
    case AST_FUNCTION:
      return (Value) { .kind = VALUE_FUNCTION, .function = (AstFunction*) ast };
    case AST_FUNCTION_CALL: {
      AstFunctionCall * callAst = (AstFunctionCall*) ast;
      Value function = evaluate ( state, callAst -> function );
      return function_call ( state, function, true, callAst -> arguments, callAst -> argument_cnt );
    }
    case AST_ARRAY: {
      AstArray * arrAst = (AstArray*) ast;
      Value size = evaluate ( state, arrAst -> size );
      if ( size . kind != VALUE_INTEGER || size . integer < 0 ) {
        fprintf ( stderr, "Invalid array size.\n" );
        exit ( 5 );
      }
      size_t count = (size_t) size . integer;
      /* Slot 0 holds the length. Reject sizes whose byte count
         (count + 1) * sizeof (Value) would wrap around size_t. */
      if ( count >= SIZE_MAX / sizeof (Value) ) {
        fprintf ( stderr, "Invalid array size.\n" );
        exit ( 5 );
      }
      Value * addr = (Value*) eval_heap_alloc ( state, sizeof (Value) * ( count + 1 ), alignof (Value) );
      addr [ 0 ] = (Value) { .kind = VALUE_INTEGER, .integer = size . integer };
      /* The initializer is re-evaluated once per element. */
      for ( size_t i = 1; i <= count; ++i ) {
        Value val = evaluate ( state, arrAst -> initializer );
        addr [ i ] = val;
      }
      return (Value) { .kind = VALUE_ARRAY, .address = addr };
    }
    case AST_OBJECT: {
      AstObject * objAst = (AstObject*) ast;
      Value extends = evaluate ( state, objAst -> extends );
      Object * addr = (Object*) eval_heap_alloc ( state, sizeof (Object), alignof (Object) );
      addr -> member_cnt = objAst -> member_cnt;
      addr -> extends = extends;
      addr -> members = (SimpleEntry*) eval_heap_alloc ( state, sizeof (SimpleEntry) * addr -> member_cnt, alignof (SimpleEntry) );
      for ( size_t i = 0; i < addr -> member_cnt; ++i ) {
        if ( objAst -> members [ i ] -> kind != AST_DEFINITION ) {
          fprintf ( stderr, "Invalid object member.\n" );
          exit ( 7 );
        }
        AstDefinition * defAst = (AstDefinition*) objAst -> members [ i ];
        /* Evaluate the member definition in a throwaway scope so the name it
           defines does not leak into the enclosing environment. */
        env_push ( state );
        Value value = evaluate ( state, objAst -> members [ i ] );
        env_pop ( state );
        addr -> members [ i ] . name  = defAst -> name;
        addr -> members [ i ] . value = value;
      }
      return (Value) { .kind = VALUE_OBJECT, .address = addr };
    }
    case AST_FIELD_ACCESS: {
      AstFieldAccess * accessAst = (AstFieldAccess *) ast;
      Value object_value = evaluate ( state, accessAst -> object );
      if ( object_value . kind != VALUE_OBJECT ) {
        fprintf ( stderr, "Trying to access field of non object.\n" );
        exit ( 8 );
      }
      Value * result = get_object_field ( &object_value, accessAst -> field );
      return *result;
    }
    case AST_FIELD_ASSIGNMENT: {
      AstFieldAssignment * assignAst = (AstFieldAssignment*) ast;
      Value object_value = evaluate ( state, assignAst -> object );
      if ( object_value . kind != VALUE_OBJECT ) {
        fprintf ( stderr, "Trying to assign to a field of non object.\n" );
        exit ( 8 );
      }
      Value value = evaluate ( state, assignAst -> value );
      Value * result = get_object_field ( &object_value, assignAst -> field );
      *result = value;
      return value;
    }
    default:
      /* Node kinds without an implementation here evaluate to null. */
      break;
  }
  return (Value) { .kind = VALUE_NULL };
}