#include <stdlib.h>
#include <stdio.h>
#include <inttypes.h>

#include "parser.h"
#include "ast_interpreter.h"

/*
 * Open a new, empty scope on top of the current environment chain.
 * The new scope links back to whatever was current (possibly NULL) via `prev`
 * and becomes state -> current_env. Released again by env_pop.
 */
void env_push ( ASTInterpreterState * state ) {
  Environment * enclosing = state -> current_env;
  Environment * scope = malloc ( sizeof *scope );
  *scope = (Environment) { .start = NULL, .prev = enclosing };
  state -> current_env = scope;
}

/*
 * Close the innermost scope: free every entry it owns and the scope node
 * itself, then make its parent the current environment. The Value objects the
 * entries point to are not freed here.
 */
void env_pop ( ASTInterpreterState * state ) {
  Environment * scope = state -> current_env;
  EnvironmentEntry * cursor = scope -> start;
  while ( cursor != NULL ) {
    EnvironmentEntry * following = cursor -> next;
    free ( cursor );
    cursor = following;
  }
  state -> current_env = scope -> prev;
  free ( scope );
}

/*
 * Assign `value` to the variable `name`.
 *
 * Looks the name up in the local scope chain (innermost first), then in the
 * global environment, and overwrites the first binding found. If the name is
 * bound nowhere, a new binding is prepended to the current scope, or to the
 * global environment when no local scope is active.
 */
void env_put ( ASTInterpreterState * state, Str name, Value * value ) {
  for ( Environment * scope = state -> current_env; scope != NULL; scope = scope -> prev ) {
    for ( EnvironmentEntry * it = scope -> start; it != NULL; it = it -> next ) {
      if ( str_eq ( it -> name, name ) ) {
        it -> value = value;
        return;
      }
    }
  }

  for ( EnvironmentEntry * it = state -> global_env . start; it != NULL; it = it -> next ) {
    if ( str_eq ( it -> name, name ) ) {
      it -> value = value;
      return;
    }
  }

  /* Unbound name: create it in the innermost available scope. */
  Environment * target = state -> current_env != NULL ? state -> current_env : & state -> global_env;
  EnvironmentEntry * fresh = malloc ( sizeof *fresh );
  *fresh = (EnvironmentEntry) { .next = target -> start, .name = name, .value = value };
  target -> start = fresh;
}

/*
 * Define `name` as `value` in the current scope only (the global environment
 * when no local scope is active). Outer scopes are not consulted. An existing
 * binding of the same name in that scope is overwritten; otherwise a new
 * entry is appended at the end of the scope's list.
 */
void env_def ( ASTInterpreterState * state, Str name, Value * value ) {
  Environment * scope = state -> current_env;
  if ( scope == NULL )
    scope = & state -> global_env;

  /* `link` always points at the slot where a new tail entry would go. */
  EnvironmentEntry ** link = & scope -> start;
  while ( *link != NULL ) {
    EnvironmentEntry * current = *link;
    if ( str_eq ( name, current -> name ) ) {
      current -> value = value;
      return;
    }
    link = & current -> next;
  }

  EnvironmentEntry * fresh = malloc ( sizeof *fresh );
  fresh -> name = name;
  fresh -> value = value;
  fresh -> next = NULL;
  *link = fresh;
}

/*
 * Look up the value bound to `name`.
 *
 * Searches the local scope chain (innermost first), then the global
 * environment. An unknown name is not an error: it is bound to a freshly made
 * null value, prepended to the current scope (or to the global environment
 * when no local scope is active), and that null value is returned.
 */
Value * env_get ( ASTInterpreterState * state, Str name ) {
  for ( Environment * scope = state -> current_env; scope != NULL; scope = scope -> prev ) {
    for ( EnvironmentEntry * it = scope -> start; it != NULL; it = it -> next ) {
      if ( str_eq ( it -> name, name ) )
        return it -> value;
    }
  }

  for ( EnvironmentEntry * it = state -> global_env . start; it != NULL; it = it -> next ) {
    if ( str_eq ( it -> name, name ) )
      return it -> value;
  }

  Environment * target = state -> current_env != NULL ? state -> current_env : & state -> global_env;
  EnvironmentEntry * fresh = malloc ( sizeof *fresh );
  *fresh = (EnvironmentEntry) { .name = name, .value = make_null ( state -> heap ), .next = target -> start };
  target -> start = fresh;
  return fresh -> value;
}

/*
 * Initialise an interpreter state: attach the value heap, start with an empty
 * global environment and no active local scope.
 */
void state_init ( ASTInterpreterState * state, Heap * heap ) {
  Environment * globals = & state -> global_env;
  globals -> start = NULL;
  globals -> prev = NULL;
  state -> current_env = NULL;
  state -> heap = heap;
}

/*
 * Tear down an interpreter state.
 *
 * Destroys the value heap, frees every binding of the global environment and
 * every local scope still on the current_env chain. Local scopes are
 * malloc'd by env_push, so both their entries and the Environment node itself
 * are freed. The global environment is embedded in the state, so only its
 * entries are freed. Individual Value objects are not freed here; presumably
 * heap_destroy owns them (TODO confirm).
 *
 * Afterwards the environment pointers are reset so the state holds no
 * dangling references.
 */
void state_destroy ( ASTInterpreterState * state ) {
  heap_destroy ( state -> heap );

  EnvironmentEntry * entry = state -> global_env . start;
  while ( entry ) {
    EnvironmentEntry * next = entry -> next;
    free ( entry );
    entry = next;
  }
  state -> global_env . start = NULL;

  Environment * env = state -> current_env;
  while ( env ) {
    Environment * prev = env -> prev;
    entry = env -> start;
    while ( entry ) {
      EnvironmentEntry * next = entry -> next;
      free ( entry );
      entry = next;
    }
    /* The scope node was allocated by env_push; previously it was leaked. */
    free ( env );
    env = prev;
  }
  state -> current_env = NULL;
}

/*
 * Call a function, or dispatch a method on an object / primitive value.
 *
 * state       interpreter state. Its current_env is saved, replaced by a fresh
 *             scope for the duration of the call, and restored afterwards.
 * callee      the function value (is_function == true), or the receiver of a
 *             method call (is_function == false).
 * is_function true for a plain function call, false for a method call.
 * arguments   argc argument expressions. They are evaluated here, left to
 *             right, in the caller's environment, before dispatch.
 * name        method name. Only read when is_function is false.
 *
 * Method dispatch walks the receiver's `extends` chain outwards. The first
 * object that owns a member called *name supplies the method, and that object
 * is bound to `this`. If no object in the chain has such a member, the call is
 * handed to try_operator on get_base(callee).
 *
 * Returns the value of the function body (or of the operator). Errors are
 * reported on stderr and terminate the process.
 */
Value * function_call ( ASTInterpreterState * state, Value * callee, bool is_function, Ast ** arguments, size_t argc, Str * name ) {
  Value ** values = (Value **) malloc ( argc * sizeof (Value*) );
  for ( size_t i = 0; i < argc; ++i )
    values [ i ] = evaluate ( state, arguments [ i ] );
  Value * ret = NULL;
  Value * function = callee;
  if ( ! is_function ) {
    function = NULL;
    Value * next = callee;
    /* Stop at the first match, so a member defined on the receiver (or a
       nearer ancestor) overrides a same-named member further up the chain. */
    while ( ! function && next -> kind == VALUE_OBJECT ) {
      ObjectValue * obj = (ObjectValue *) next;
      for ( size_t i = 0; i < obj -> member_cnt; ++i ) {
        if ( str_eq ( obj -> members [ i ] . name, *name ) ) {
          function = obj -> members [ i ] . value;
          callee = next;
          break;
        }
      }
      next = obj -> extends;
    }
    if ( ! function ) {
      /* No member with this name anywhere in the chain: try a built-in
         operator on the chain's base value. */
      Value * base = get_base ( callee );
      ret = try_operator ( state -> heap, base, values, argc, name );
      free ( values );
      if ( ret )
        return ret;
      fprintf ( stderr, "Method %.*s does not exist for this object and/or arguments.\n", (int) name -> len, name -> str );
      exit ( 13 );
    }
  }
  if ( function -> kind != VALUE_FUNCTION ) {
    fprintf ( stderr, "Invalid callee.\n" );
    exit ( 4 );
  }
  FunctionValue * f = (FunctionValue*) function;
  /* The body runs in a scope whose `prev` is NULL, so only its own scope and
     the global environment are visible, not the caller's locals. */
  Environment * saved_env = state -> current_env;
  state -> current_env = NULL;
  env_push ( state );
  env_def ( state, STR ("this"), is_function ? make_null ( state -> heap ) : callee );
  /* NOTE(review): argc is never compared with the callee's parameter count.
     If argc is larger, parameters[i] below is read out of bounds. TODO: add an
     arity check once the AstFunction parameter-count field is confirmed. */
  for ( size_t i = 0; i < argc; ++i )
    env_def ( state, f -> function -> parameters [ i ], values [ i ] );
  ret = evaluate ( state, f -> function -> body );
  env_pop ( state );
  state -> current_env = saved_env;

  free ( values );
  return ret;
}

/*
 * Write a textual representation of `value` to stdout.
 *
 * Arrays print as `[a, b, ...]`. Objects print as
 * `object(..=<parent>, name=value, ...)`, where the `..=` part appears only
 * when the parent is not null. Nested values are printed recursively.
 *
 * Side effect: an object's member array is sorted in place with
 * compare_entry before printing.
 */
void print_value ( Value * value ) {
  switch ( value -> kind ) {
    case VALUE_INTEGER:
      printf ( "%" PRIi32, ((IntValue*) value ) -> integer );
      break;
    case VALUE_BOOLEAN:
      fputs ( ((BoolValue*) value ) -> boolean ? "true" : "false", stdout );
      break;
    case VALUE_NULL:
      fputs ( "null", stdout );
      break;
    case VALUE_FUNCTION:
      fputs ( "function", stdout );
      break;
    case VALUE_ARRAY: {
      ArrayValue * array = (ArrayValue*) value;
      putchar ( '[' );
      for ( i32 idx = 0; idx < array -> length; ++idx ) {
        if ( idx > 0 )
          fputs ( ", ", stdout );
        print_value ( array -> elements [ idx ] );
      }
      putchar ( ']' );
      break;
    }
    case VALUE_OBJECT: {
      ObjectValue * object = (ObjectValue*) value;
      fputs ( "object(", stdout );
      bool need_separator = false;
      if ( object -> extends -> kind != VALUE_NULL ) {
        fputs ( "..=", stdout );
        print_value ( object -> extends );
        need_separator = true;
      }
      /* Sorting mutates the object so members print in a stable order. */
      if ( object -> member_cnt )
        qsort ( object -> members, object -> member_cnt, sizeof (SimpleEntry), compare_entry );
      for ( size_t k = 0; k < object -> member_cnt; ++k ) {
        if ( need_separator )
          fputs ( ", ", stdout );
        need_separator = true;
        printf ( "%.*s=", (int) object -> members [ k ] . name . len, object -> members [ k ] . name . str );
        print_value ( object -> members [ k ] . value );
      }
      putchar ( ')' );
      break;
    }
    case VALUE_INVALID:
      fprintf ( stderr, "Invalid value.\n" );
      break;
  }
}

/*
 * Print `format` to stdout, substituting the values in `args` for `~`
 * placeholders in order.
 *
 * Supported escapes: \~  \n  \"  \r  \t  \\ . An unknown escape, a backslash
 * at the very end of the format, or more placeholders than arguments is
 * reported on stderr and exits with status 1. Surplus arguments are ignored.
 */
void fml_print ( Str format, Value ** args, size_t argc ) {
  u8 c;
  size_t arg = 0;
  for ( size_t i = 0; i < format . len; ++i ) {
    c = format . str [ i ];
    if ( c == '\\' ) {
      /* The loop is bounded by format.len and the data may not be
         NUL-terminated, so a trailing backslash must not read str[len]. */
      if ( i + 1 >= format . len ) {
        fprintf ( stderr, "Unterminated escape sequence at end of format string.\n" );
        exit ( 1 );
      }
      ++i;
      c = format . str [ i ];
      switch ( c ) {
        case '~':
          putchar ( '~' );
          break;
        case 'n':
          putchar ( '\n' );
          break;
        case '"':
          putchar ( '"' );
          break;
        case 'r':
          putchar ( '\r' );
          break;
        case 't':
          putchar ( '\t' );
          break;
        case '\\':
          putchar ( '\\' );
          break;
        default:
          fprintf ( stderr, "Invalid escaped character %c.\n", c );
          exit ( 1 );
      }
    } else if ( c == '~' ) {
      if ( arg == argc ) {
        fprintf ( stderr, "Too many placeholders, not enough arguments.\n" );
        exit ( 1 );
      }
      print_value ( args [ arg++ ] );
    } else {
      putchar ( c );
    }
  }
}

/*
 * Evaluate `ast` in the current environment of `state` and return its value.
 *
 * Blocks, the chosen branch of a conditional, each loop-body iteration, each
 * array-element initializer and each object-member initializer run inside a
 * fresh scope (env_push / env_pop). Index access and assignment are
 * dispatched as `get` / `set` method calls through function_call.
 * Runtime errors are reported on stderr and terminate the process via exit().
 * Returns NULL only if `ast -> kind` matches none of the cases below.
 */
Value * evaluate ( ASTInterpreterState * state, Ast * ast ) {
  switch ( ast -> kind ) {
    case AST_INTEGER:
      return make_int ( state -> heap, ((AstInteger*) ast ) -> value );
    case AST_BOOLEAN:
      return make_bool ( state -> heap, ((AstBoolean*) ast ) -> value );
    case AST_NULL:
      return make_null ( state -> heap );
    case AST_PRINT: {
      AstPrint * printAst = (AstPrint*) ast;
      Value ** args = (Value **) malloc ( sizeof (Value*) * printAst -> argument_cnt );
      for ( size_t i = 0; i < printAst -> argument_cnt; ++i )
        args [ i ] = evaluate ( state, printAst -> arguments [ i ] );
      fml_print ( printAst -> format, args, printAst -> argument_cnt );
      free ( args );
      return make_null ( state -> heap );
    }
    case AST_TOP: {
      AstTop * topAst = (AstTop*) ast;
      /* An empty program has no expressions[0]; its value is null. */
      if ( topAst -> expression_cnt == 0 )
        return make_null ( state -> heap );
      Value * val = evaluate ( state, topAst -> expressions [ 0 ] );
      for ( size_t i = 1; i < topAst -> expression_cnt; ++i )
        val = evaluate ( state, topAst -> expressions [ i ] );
      return val;
    }
    case AST_DEFINITION: {
      AstDefinition * defAst = (AstDefinition*) ast;
      Value * value = evaluate ( state, defAst -> value );
      env_def ( state, defAst -> name, value );
      return value;
    }
    case AST_VARIABLE_ACCESS:
      return env_get ( state, ((AstVariableAccess*) ast ) -> name );
    case AST_VARIABLE_ASSIGNMENT: {
      AstVariableAssignment * assignAst = (AstVariableAssignment*) ast;
      Value * value = evaluate ( state, assignAst -> value );
      env_put ( state, assignAst -> name, value );
      return value;
    }
    case AST_BLOCK: {
      AstBlock * blockAst = (AstBlock*) ast;
      /* An empty block has no expressions[0]; its value is null. */
      if ( blockAst -> expression_cnt == 0 )
        return make_null ( state -> heap );
      env_push ( state );
      Value * value = evaluate ( state, blockAst -> expressions [ 0 ] );
      for ( size_t i = 1; i < blockAst -> expression_cnt; ++i )
        value = evaluate ( state, blockAst -> expressions [ i ] );
      env_pop ( state );
      return value;
    }
    case AST_CONDITIONAL: {
      AstConditional * condAst = (AstConditional*) ast;
      /* The condition runs in the enclosing scope; only the chosen branch
         gets a new scope. */
      Value * cond = evaluate ( state, condAst -> condition );
      env_push ( state );
      Ast * cont = value_to_bool ( cond ) ? condAst -> consequent : condAst -> alternative;
      Value * value = evaluate ( state, cont );
      env_pop ( state );
      return value;
    }
    case AST_LOOP: {
      AstLoop * loopAst = (AstLoop*) ast;
      while ( value_to_bool ( evaluate ( state, loopAst -> condition ) ) ) {
        env_push ( state );
        evaluate ( state, loopAst -> body );
        env_pop ( state );
      }
      return make_null ( state -> heap );
    }
    case AST_FUNCTION:
      return make_function ( state -> heap, (AstFunction*) ast );
    case AST_FUNCTION_CALL: {
      AstFunctionCall * callAst = (AstFunctionCall*) ast;
      Value * function = evaluate ( state, callAst -> function );
      return function_call ( state, function, true, callAst -> arguments, callAst -> argument_cnt, NULL );
    }
    case AST_METHOD_CALL: {
      AstMethodCall * methodAst = (AstMethodCall*) ast;
      Value * object = evaluate ( state, methodAst -> object );
      if ( object -> kind == VALUE_FUNCTION ) {
        fprintf ( stderr, "Calling method on function.\n" );
        exit ( 4 );
      }
      return function_call ( state, object, false, methodAst -> arguments, methodAst -> argument_cnt, &methodAst -> name );
    }
    case AST_ARRAY: {
      AstArray * arrAst = (AstArray*) ast;
      Value * size = evaluate ( state, arrAst -> size );
      IntValue * int_size = (IntValue*) size;
      /* int_size is only dereferenced once the kind check has passed. */
      if ( size -> kind != VALUE_INTEGER || int_size -> integer < 0 ) {
        fprintf ( stderr, "Invalid array size.\n" );
        exit ( 5 );
      }
      ArrayValue * array = (ArrayValue*) make_array ( state -> heap, int_size -> integer );
      /* The initializer is re-evaluated for every element, each in its own scope. */
      for ( i32 i = 0; i < array -> length; ++i ) {
        env_push ( state );
        array -> elements [ i ] = evaluate ( state, arrAst -> initializer );
        env_pop ( state );
      }
      return (Value*) array;
    }
    case AST_INDEX_ACCESS: {
      AstIndexAccess * iaccess = (AstIndexAccess*) ast;
      Value * object = evaluate ( state, iaccess -> object );
      Str tmp = STR ("get");
      return function_call ( state, object, false, &iaccess -> index, 1, &tmp );
    }
    case AST_INDEX_ASSIGNMENT: {
      AstIndexAssignment * iassign = (AstIndexAssignment*) ast;
      Value * object = evaluate ( state, iassign -> object );
      Ast * arguments [2];
      arguments [ 0 ] = iassign -> index;
      arguments [ 1 ] = iassign -> value;
      Str tmp = STR ("set");
      return function_call ( state, object, false, arguments, 2, &tmp );
    }
    case AST_OBJECT: {
      AstObject * objAst = (AstObject*) ast;
      ObjectValue * object = (ObjectValue*) make_object ( state -> heap, objAst -> member_cnt );
      object -> extends = evaluate ( state, objAst -> extends );
      Value * value;
      AstDefinition * defAst;
      for ( size_t i = 0; i < object -> member_cnt; ++i ) {
        if ( objAst -> members [ i ] -> kind != AST_DEFINITION ) {
          fprintf ( stderr, "Invalid object member.\n" );
          exit ( 7 );
        }
        defAst = (AstDefinition*) objAst -> members [ i ];
        env_push ( state );
        value = evaluate ( state, defAst -> value );
        env_pop ( state );
        object -> members [ i ] . name  = defAst -> name;
        object -> members [ i ] . value = value;
      }
      return (Value*) object;
    }
    case AST_FIELD_ACCESS: {
      AstFieldAccess * accessAst = (AstFieldAccess *) ast;
      Value * object_value = evaluate ( state, accessAst -> object );
      if ( object_value -> kind != VALUE_OBJECT ) {
        fprintf ( stderr, "Trying to access field of non object.\n" );
        exit ( 8 );
      }
      Value ** result = get_object_field ( object_value, accessAst -> field );
      if ( ! result ) {
        fprintf ( stderr, "Field %.*s was not found in specified object.\n", (int) accessAst -> field . len, accessAst -> field . str );
        exit ( 8 );
      }
      return *result;
    }
    case AST_FIELD_ASSIGNMENT: {
      AstFieldAssignment * assignAst = (AstFieldAssignment*) ast;
      Value * object_value = evaluate ( state, assignAst -> object );
      if ( object_value -> kind != VALUE_OBJECT ) {
        fprintf ( stderr, "Trying to assign to a field of non object.\n" );
        exit ( 8 );
      }
      Value * value = evaluate ( state, assignAst -> value );
      Value ** result = get_object_field ( object_value, assignAst -> field );
      if ( ! result ) {
        fprintf ( stderr, "Field %.*s was not found in specified object.\n", (int) assignAst -> field . len, assignAst -> field . str );
        exit ( 8 );
      }
      *result = value;
      return value;
    }
  }
  return NULL;
}