// FML parser
// Michal Vlasák, FIT CTU, 2023

#include <stdio.h>

#include "parser.h"

// Marks a code path that must never execute; reports the call site and aborts.
#define UNREACHABLE() unreachable(__FILE__, __LINE__)

// Print a diagnostic naming `file`:`line` to stderr and terminate the process
// with EXIT_FAILURE. Never returns.
_Noreturn void
unreachable(char *file, size_t line)
{
	fprintf(stderr, "ERROR: unreachable code reached at %s:%zu\n", file, line);
	exit(EXIT_FAILURE);
}

Michal Vlasák's avatar
Michal Vlasák committed
bool str_eq(Str a, Str b)
{
	return a.len == b.len && memcmp(a.str, b.str, a.len) == 0;
}

// Three-way byte-wise comparison of `a` and `b`.
// Returns a negative, zero or positive value; when one string is a prefix of
// the other, the shorter one orders first.
int str_cmp(Str a, Str b)
{
	size_t common = a.len < b.len ? a.len : b.len;
	// Skip memcmp for an empty common prefix: its pointer arguments must be
	// valid even when the length is zero, and an empty Str may hold null.
	int cmp = common == 0 ? 0 : memcmp(a.str, b.str, common);
	if (cmp != 0) {
		return cmp;
	}
	return (a.len > b.len) - (b.len > a.len);
}

// --- Lexer ---
// Lexer state: a cursor over the source bytes.
typedef struct {
	const u8 *pos; // next byte to be examined
	const u8 *end; // one past the last source byte (see lex_create)
} Lexer;

// States of the scanner state machine driven by lex_next.
typedef enum {
	LS_START,              // nothing of the token consumed yet
	LS_IDENTIFIER,         // inside an identifier or keyword
	LS_NUMBER,             // inside a run of digits (possibly after '-')
	LS_STRING,             // inside a string literal, after the opening quote
	LS_STRING_ESC,         // just after a backslash inside a string literal
	LS_SLASH,              // seen '/': division or start of a comment
	LS_LINE_COMMENT,       // inside a comment that runs to end of line
	LS_BLOCK_COMMENT,      // inside a block comment
	LS_BLOCK_COMMENT_STAR, // seen '*' inside a block comment
	LS_MINUS,              // seen '-': minus, "->" or a negative number
	LS_EQUAL,              // seen '=': "=" or "=="
	LS_GREATER,            // seen '>': ">" or ">="
	LS_LESS,               // seen '<': "<", "<=" or "<-"
	LS_EXCLAM,             // seen '!': only "!=" is valid
} LexState;

// Operator associativity; left_info uses it to derive the right binding power.
typedef enum {
	ASSOC_LEFT,
	ASSOC_RIGHT,
} Associativity;

// Binding powers, weakest first. The numeric order matters: expression_bp
// compares these values and left_info computes the right binding power as
// the precedence plus one for left-associative tokens.
typedef enum {
	PREC_NONE, // tokens that never continue an expression
	PREC_EXPR, // minimum binding power used by expression()
	PREC_ASGN,
	PREC_DISJ,
	PREC_CONJ,
	PREC_CMP,
	PREC_ADD,
	PREC_MUL,
	PREC_POST, // call, field access, indexing
	PREC_TOP,  // used with lefterr so the error handler always runs
} Precedence;

// X-macro table of all tokens. Each row is expanded through one of three
// caller-supplied macros: KW for keywords, PU for punctuation, OT for others.
// Columns: enumerator suffix (TK_<token>), textual representation used in
// error messages, prefix parser (nud), infix/postfix parser (led), precedence
// suffix (PREC_<prec>) and associativity suffix (ASSOC_<assoc>).
// The row order defines TokenKind values; tok_is_identifier relies on
// BAR..PERCENT being contiguous.
#define TOKENS(KW, PU, OT) \
	/* token          repr             nud       led       prec  assoc*/\
	OT(NUMBER,        "a number",      primary,  lefterr,  TOP,  LEFT)  \
	OT(IDENTIFIER,    "an identifier", ident,    lefterr,  TOP,  LEFT)  \
	OT(STRING,        "a string",      nullerr,  lefterr,  TOP,  LEFT)  \
                                                                            \
	PU(BAR,           "|",             nullerr,  binop,    DISJ, LEFT)  \
	PU(AMPERSANT,     "&",             nullerr,  binop,    CONJ, LEFT)  \
	PU(EQUAL_EQUAL,   "==",            nullerr,  binop,    CMP,  LEFT)  \
	PU(BANG_EQUAL,    "!=",            nullerr,  binop,    CMP,  LEFT)  \
	PU(GREATER,       ">",             nullerr,  binop,    CMP,  LEFT)  \
	PU(LESS,          "<",             nullerr,  binop,    CMP,  LEFT)  \
	PU(GREATER_EQUAL, ">=",            nullerr,  binop,    CMP,  LEFT)  \
	PU(LESS_EQUAL,    "<=",            nullerr,  binop,    CMP,  LEFT)  \
	PU(PLUS,          "+",             nullerr,  binop,    ADD,  LEFT)  \
	PU(MINUS,         "-",             nullerr,  binop,    ADD,  LEFT)  \
	PU(ASTERISK,      "*",             nullerr,  binop,    MUL,  LEFT)  \
	PU(SLASH,         "/",             nullerr,  binop,    MUL,  LEFT)  \
	PU(PERCENT,       "%",             nullerr,  binop,    MUL,  LEFT)  \
	                                                                    \
	PU(SEMICOLON,     ";",             nullerr,  stop,     NONE, LEFT)  \
	PU(LPAREN,        "(",             paren,    call,     POST, LEFT)  \
	PU(RPAREN,        ")",             nullerr,  stop,     NONE, LEFT)  \
	PU(EQUAL,         "=",             nullerr,  eqerr,    ASGN, LEFT)  \
	PU(LARROW,        "<-",            nullerr,  assign,   ASGN, RIGHT) \
	PU(RARROW,        "->",            nullerr,  lefterr,  TOP,  LEFT)  \
	PU(DOT,           ".",             nullerr,  field,    POST, LEFT)  \
	PU(LBRACKET,      "[",             nullerr,  indexing, POST, LEFT)  \
	PU(RBRACKET,      "]",             nullerr,  stop,     NONE, LEFT)  \
	PU(COMMA,         ",",             nullerr,  stop,     NONE, LEFT)  \
	                                                                    \
	KW(BEGIN,         "begin",         block,    stop,     NONE, LEFT)  \
	KW(END,           "end",           nullerr,  stop,     NONE, LEFT)  \
	KW(IF,            "if",            cond,     lefterr,  TOP,  LEFT)  \
	KW(THEN,          "then",          nullerr,  stop,     NONE, LEFT)  \
	KW(ELSE,          "else",          nullerr,  stop,     NONE, LEFT)  \
	KW(LET,           "let",           let,      lefterr,  TOP,  LEFT)  \
	KW(NULL,          "null",          primary,  lefterr,  TOP,  LEFT)  \
	KW(PRINT,         "print",         print,    lefterr,  TOP,  LEFT)  \
	KW(OBJECT,        "object",        object,   lefterr,  TOP,  LEFT)  \
	KW(EXTENDS,       "extends",       nullerr,  lefterr,  TOP,  LEFT)  \
	KW(WHILE,         "while",         loop,     lefterr,  TOP,  LEFT)  \
	KW(DO,            "do",            nullerr,  stop,     NONE, LEFT)  \
	KW(FUNCTION,      "function",      function, lefterr,  TOP,  LEFT)  \
	KW(ARRAY,         "array",         array,    lefterr,  TOP,  LEFT)  \
	KW(TRUE,          "true",          primary,  lefterr,  TOP,  LEFT)  \
	KW(FALSE,         "false",         primary,  lefterr,  TOP,  LEFT)  \
	                                                                    \
	OT(EOF,           "end of input",  nullerr,  stop,     NONE, LEFT)  \
	OT(ERROR,         "lex error",     nullerr,  lefterr,  TOP,  LEFT)

// One TK_<token> enumerator per TOKENS row, in table order. The values index
// tok_repr, null_info and left_info.
typedef enum {
	#define TOK_ENUM(tok, ...) TK_##tok,
	TOKENS(TOK_ENUM, TOK_ENUM, TOK_ENUM)
	#undef TOK_ENUM
} TokenKind;

// Human-readable token names for error messages, indexed by TokenKind.
// Keywords and punctuation are wrapped in single quotes; the "other" tokens
// use their descriptive text as is.
static const char *tok_repr[] = {
	#define TOK_STR(tok, str, ...) "'"str"'",
	#define TOK_STR_OTHER(tok, str, ...) str,
	TOKENS(TOK_STR, TOK_STR, TOK_STR_OTHER)
	#undef TOK_STR
	#undef TOK_STR_OTHER
};

// Keyword spellings with their token kinds, built from the KW rows of TOKENS.
// lex_next scans it linearly to reclassify identifiers.
static struct {
	const char *str;
	TokenKind tok;
} keywords[] = {
	#define TOK_KW(tok, str, ...) { str, TK_##tok },
	#define TOK_OTHER(tok, str, ...)
	TOKENS(TOK_KW, TOK_OTHER, TOK_OTHER)
	#undef TOK_KW
	#undef TOK_OTHER
};

// Tell whether a token may be used where a name is expected: a plain
// identifier, or one of the operator tokens TK_BAR..TK_PERCENT (contiguous in
// the TOKENS table).
static bool
tok_is_identifier(TokenKind kind)
{
	if (kind == TK_IDENTIFIER) {
		return true;
	}
	bool is_operator = TK_BAR <= kind && kind <= TK_PERCENT;
	return is_operator;
}

// A lexed token: its kind and the slice of source text it covers.
// For TK_STRING the slice excludes the surrounding quotes (see lex_next).
typedef struct {
	TokenKind kind;
	Str str;
} Token;

// Build a lexer positioned at the first byte of `source`; the end marker is
// one past its last byte.
Lexer
lex_create(Str source)
{
	Lexer lexer;
	lexer.pos = source.str;
	lexer.end = source.str + source.len;
	return lexer;
}

// Case-label list matching '_' and the ASCII letters. Meant to be used as
// `case ALPHA:` - the first label deliberately lacks the `case` keyword.
#define ALPHA '_': case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': case 'g': case 'h': case 'i': case 'j': case 'k': case 'l': case 'm': case 'n': case 'o': case 'p': case 'q': case 'r': case 's': case 't': case 'v': case 'u': case 'w': case 'x': case 'y': case 'z': case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': case 'G': case 'H': case 'I': case 'J': case 'K': case 'L': case 'M': case 'N': case 'O': case 'P': case 'Q': case 'R': case 'S': case 'T': case 'V': case 'U': case 'W': case 'X': case 'Y': case 'Z'

// Case-label list matching the ASCII decimal digits; use as `case DIGIT:`.
#define DIGIT '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': case '8': case '9'

static void
Michal Vlasák's avatar
Michal Vlasák committed
lex_next(Lexer *lexer, Token *token)
{
	LexState state = LS_START;
	TokenKind tok = TK_ERROR;
	int end_offset = 0;
	const u8 *start = lexer->pos;
	size_t length;
	while (lexer->pos != lexer->end) {
		u8 c = *lexer->pos;
		switch (state) {
		case LS_START: switch (c) {
			case ' ': case '\t': case '\n': start += 1; break;
			case ALPHA: state = LS_IDENTIFIER; break;
			case DIGIT: state = LS_NUMBER; break;
			case '"': state = LS_STRING; start += 1; break;
			case '/': state = LS_SLASH; break;
			case '-': state = LS_MINUS; break;
			case '=': state = LS_EQUAL; break;
			case '>': state = LS_GREATER; break;
			case '<': state = LS_LESS; break;
			case '!': state = LS_EXCLAM; break;
			case ';': tok = TK_SEMICOLON; goto done;
			case '|': tok = TK_BAR; goto done;
			case '&': tok = TK_AMPERSANT; goto done;
			case '+': tok = TK_PLUS; goto done;
			case '*': tok = TK_ASTERISK; goto done;
			case '%': tok = TK_PERCENT; goto done;
			case '(': tok = TK_LPAREN; goto done;
			case ')': tok = TK_RPAREN; goto done;
			case '.': tok = TK_DOT; goto done;
			case '[': tok = TK_LBRACKET; goto done;
			case ']': tok = TK_RBRACKET; goto done;
			case ',': tok = TK_COMMA; goto done;
			default:  tok = TK_ERROR; goto done;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_IDENTIFIER: switch (c) {
			case ALPHA: case DIGIT: break;
			default: tok = TK_IDENTIFIER; goto prev_done;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_NUMBER: switch (c) {
			case DIGIT: break;
			default: tok = TK_NUMBER; goto prev_done;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_STRING: switch (c) {
			case '"': tok = TK_STRING; end_offset = -1; goto done;
			case '\\': state = LS_STRING_ESC; break;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_STRING_ESC: switch (c) {
			case 'n': case 't': case 'r': case '~': case '"': case '\\': state = LS_STRING; break;
			default: goto err;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_SLASH: switch (c) {
			case '/': state = LS_LINE_COMMENT; start += 2; break;
			case '*': state = LS_BLOCK_COMMENT; start += 2; break;
			default: tok = TK_SLASH; goto prev_done;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_LINE_COMMENT: switch (c) {
			case '\n': state = LS_START; start = lexer->pos + 1; break;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_BLOCK_COMMENT: switch (c) {
			case '*': state = LS_BLOCK_COMMENT_STAR; break;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_BLOCK_COMMENT_STAR: switch (c) {
			case '*': break;
			case '/': state = LS_START; start = lexer->pos + 1; break;
			default: state = LS_BLOCK_COMMENT; break;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_MINUS: switch (c) {
			case '>': tok = TK_RARROW; goto done;
			case DIGIT: state = LS_NUMBER; break;
			default: tok = TK_MINUS; goto prev_done;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_EQUAL: switch (c) {
			case '=': tok = TK_EQUAL_EQUAL; goto done;
			default: tok = TK_EQUAL; goto prev_done;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_GREATER: switch (c) {
			case '=': tok = TK_GREATER_EQUAL; goto done;
			default: tok = TK_GREATER; goto prev_done;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_LESS: switch (c) {
			case '=': tok = TK_LESS_EQUAL; goto done;
			case '-': tok = TK_LARROW; goto done;
			default: tok = TK_LESS; goto prev_done;
Michal Vlasák's avatar
Michal Vlasák committed
		case LS_EXCLAM: switch (c) {
			case '=': tok = TK_BANG_EQUAL; goto done;
			default: goto err;
Michal Vlasák's avatar
Michal Vlasák committed
		}

		lexer->pos += 1;
	}

	switch (state) {
	case LS_START: case LS_LINE_COMMENT: tok = TK_EOF; goto prev_done;
	case LS_IDENTIFIER: tok = TK_IDENTIFIER; goto prev_done;
	case LS_NUMBER: tok = TK_NUMBER; goto prev_done;
	case LS_STRING: case LS_STRING_ESC: case LS_BLOCK_COMMENT: case LS_BLOCK_COMMENT_STAR: goto err;
	case LS_SLASH: case LS_MINUS: case LS_EQUAL: case LS_GREATER: case LS_LESS: tok = TK_SLASH; goto prev_done;
	case LS_EXCLAM: goto err;
	}

done:
	lexer->pos += 1;
prev_done:
err:
	length = lexer->pos - start + end_offset;
	if (tok == TK_IDENTIFIER) {
		for (size_t i = 0; i < sizeof(keywords) / sizeof(keywords[0]); i++) {
			if (strlen(keywords[i].str) == length && memcmp((const char*) start, keywords[i].str, length) == 0) {
				tok = keywords[i].tok;
				break;
			}
		}
	}
	token->kind = tok;
	token->str.str = start;
	token->str.len = length;
}

// Parser state shared by all parsing functions.
typedef struct {
	Arena *arena;    // AST nodes and final lists are allocated here
	GArena *scratch; // temporary storage while lists are being collected
	void *user_data; // passed through to error_callback
	void (*error_callback)(void *user_data, const u8 *err_pos, const char *msg, va_list ap);
	Lexer lexer;
	Token lookahead; // next token, not yet consumed
	Token prev;      // most recently consumed token
	bool had_error;  // set once any error has been reported
	bool panic_mode; // while set, further errors are suppressed (see parser_error)
} Parser;

// Report a parse error located at `errtok` through the error callback.
// Nothing is reported while the parser is in panic mode. Otherwise the
// printf-style `msg` and its arguments are forwarded, had_error is set and
// panic mode is entered when `panic` is true.
static void
parser_error(Parser *parser, Token errtok, bool panic, const char *msg, ...)
{
	if (parser->panic_mode) {
		return;
	}
	va_list args;
	va_start(args, msg);
	parser->error_callback(parser->user_data, errtok.str.str, msg, args);
	va_end(args);
	parser->had_error = true;
	parser->panic_mode = panic;
}

static TokenKind
peek(const Parser *parser)
{
	return parser->lookahead.kind;
}

// The most recently consumed token.
static Token
prev_tok(Parser *parser)
{
	Token last = parser->prev;
	return last;
}

// Consume the lookahead token unconditionally and lex the next one.
// A lexing error in the new lookahead is reported and puts the parser into
// panic mode. Returns the consumed token.
static Token
discard(Parser *parser)
{
	Token consumed = parser->lookahead;
	parser->prev = consumed;
	lex_next(&parser->lexer, &parser->lookahead);
	bool lex_failed = parser->lookahead.kind == TK_ERROR;
	if (lex_failed) {
		parser_error(parser, parser->lookahead, true, "Unexpected character");
	}
	return consumed;
}

// Consume the lookahead if it is of the expected `kind`; otherwise report an
// error (entering panic mode) and leave the lookahead untouched.
static void
eat(Parser *parser, TokenKind kind)
{
	TokenKind found = peek(parser);
	if (found == kind) {
		discard(parser);
		return;
	}
	parser_error(parser, parser->lookahead, true, "Expected %s, found %s", tok_repr[kind], tok_repr[found]);
}

// Consume a token usable as a name (see tok_is_identifier) and store its
// text in *identifier. On mismatch an error is reported (entering panic
// mode), the lookahead is left untouched and *identifier is set to the empty
// string, so callers that pass freshly reserved memory (identifier_list)
// never end up with an indeterminate Str.
static void
eat_identifier(Parser *parser, Str *identifier)
{
	TokenKind tok = peek(parser);
	if (!tok_is_identifier(tok)) {
		parser_error(parser, parser->lookahead, true, "Expected %s, found %s", tok_repr[TK_IDENTIFIER], tok_repr[tok]);
		*identifier = (Str) { .str = NULL, .len = 0 };
		return;
	}
	discard(parser);
	*identifier = prev_tok(parser).str;
}

// Consume the lookahead only when it is of the given `kind`.
// Returns whether a token was consumed; never reports an error by itself.
static bool
try_eat(Parser *parser, TokenKind kind)
{
	if (peek(parser) != kind) {
		return false;
	}
	discard(parser);
	return true;
}

// Declare `var` as a pointer to a new zero-filled node of struct `type`
// whose kind field is set to `kind`.
#define AST_CREATE(type, var, arena, kind) type *var = ast_create_((arena), (kind), sizeof(type))
// Allocate `size` bytes from `arena`, zero them and set the kind of the Ast
// header at the start of the allocation.
static void *
ast_create_(Arena *arena, AstKind kind, size_t size)
{
	void *mem = arena_alloc(arena, size);
	Ast *node = memset(mem, 0, size);
	node->kind = kind;
	return node;
}

// Create a fresh AST_NULL node; used both for `null` placeholders and as the
// result of failed sub-parses.
static Ast *
create_null(Parser *parser)
{
	AstNull *node = ast_create_(parser->arena, AST_NULL, sizeof(AstNull));
	return &node->base;
}

static Ast *expression_bp(Parser *parser, int bp);

// Parse a complete expression, i.e. one with the weakest binding power that
// still admits every operator (PREC_EXPR).
static Ast *
expression(Parser *parser)
{
	int weakest = PREC_EXPR;
	return expression_bp(parser, weakest);
}

// Parse expressions delimited by `separator` up to and including
// `terminator`. A separator directly before the terminator is accepted and
// the list may be empty. The elements are collected on the scratch arena and
// then moved to the main arena; *list and *n receive the result.
static void
expression_list(Parser *parser, Ast ***list, size_t *n, TokenKind separator, TokenKind terminator)
{
	size_t mark = garena_save(parser->scratch);
	for (;;) {
		if (try_eat(parser, terminator)) {
			break;
		}
		// Parse the element before reserving its slot: nested lists use
		// the same scratch arena while the element is being parsed.
		Ast *item = expression(parser);
		garena_push_value(parser->scratch, Ast *, item);
		if (try_eat(parser, separator)) {
			continue;
		}
		// No separator: the list has to end right here.
		eat(parser, terminator);
		break;
	}
	*n = garena_cnt_from(parser->scratch, mark, Ast *);
	*list = move_to_arena(parser->arena, parser->scratch, mark, Ast *);
}

// Parse names delimited by `separator` up to and including `terminator`.
// Mirrors expression_list: a separator before the terminator is accepted,
// the list may be empty, and the names end up in the main arena with
// *list and *n describing them.
static void
identifier_list(Parser *parser, Str **list, size_t *n, TokenKind separator, TokenKind terminator)
{
	size_t mark = garena_save(parser->scratch);
	for (;;) {
		if (try_eat(parser, terminator)) {
			break;
		}
		Str *slot = garena_push(parser->scratch, Str);
		eat_identifier(parser, slot);
		if (try_eat(parser, separator)) {
			continue;
		}
		// No separator: the list has to end right here.
		eat(parser, terminator);
		break;
	}
	*n = garena_cnt_from(parser->scratch, mark, Str);
	*list = move_to_arena(parser->arena, parser->scratch, mark, Str);
}

// Prefix handler for tokens that cannot start an expression: reports the
// error (entering panic mode) without consuming the token and yields a null
// node so that parsing can continue.
static Ast *
nullerr(Parser *parser)
{
	const char *found = tok_repr[peek(parser)];
	parser_error(parser, parser->lookahead, true, "Invalid start of expression %s", found);
	return create_null(parser);
}

// Prefix handler for literal tokens: numbers, `null`, `true` and `false`.
// Consumes the token and returns the matching literal node. An integer
// literal that does not fit into a signed 32-bit value is reported as an
// error (without entering panic mode) and yields the value 0.
static Ast *
primary(Parser *parser)
{
	Token token = discard(parser);
	switch (token.kind) {
	case TK_NUMBER: {
		AST_CREATE(AstInteger, integer, parser->arena, AST_INTEGER);
		const u8 *pos = token.str.str;
		const u8 *end = pos + token.str.len;
		bool negative = false;
		// The lexer folds a leading '-' into the number token.
		while (pos < end && *pos == '-') {
			negative = !negative;
			pos += 1;
		}
		// Largest accepted magnitude: 2147483647, or 2147483648 when the
		// literal is negative.
		const i64 limit = (i64) 2147483647 + (negative ? 1 : 0);
		i64 value = 0;
		bool in_range = true;
		for (; pos < end; pos++) {
			value = value * 10 + (*pos - '0');
			// Stop at the first excess digit so that `value` itself can
			// never overflow, however long the literal is.
			if (value > limit) {
				in_range = false;
				break;
			}
		}
		if (!in_range) {
			parser_error(parser, token, false, "Integer literal does not fit into 32 bits");
			value = 0;
		}
		integer->value = (i32) (negative ? -value : value);
		return &integer->base;
	}
	case TK_NULL: {
		AST_CREATE(AstNull, null, parser->arena, AST_NULL);
		return &null->base;
	}
	case TK_TRUE: {
		AST_CREATE(AstBoolean, boolean, parser->arena, AST_BOOLEAN);
		boolean->value = true;
		return &boolean->base;
	}
	case TK_FALSE: {
		AST_CREATE(AstBoolean, boolean, parser->arena, AST_BOOLEAN);
		boolean->value = false;
		return &boolean->base;
	}
	default:
		// Only the tokens above name `primary` as their prefix handler.
		UNREACHABLE();
	}
}

// Prefix handler for identifiers: produces a variable access node named
// after the consumed token.
static Ast *
ident(Parser *parser)
{
	AST_CREATE(AstVariableAccess, access, parser->arena, AST_VARIABLE_ACCESS);
	Str *name_out = &access->name;
	eat_identifier(parser, name_out);
	return &access->base;
}

// Parse `begin expr; expr; ... end` into a block node.
// An empty block is turned into a null node by retagging it.
static Ast *
block(Parser *parser)
{
	AST_CREATE(AstBlock, node, parser->arena, AST_BLOCK);
	eat(parser, TK_BEGIN);
	expression_list(parser, &node->expressions, &node->expression_cnt, TK_SEMICOLON, TK_END);
	bool is_empty = node->expression_cnt == 0;
	if (is_empty) {
		// begin end => null
		node->base.kind = AST_NULL;
	}
	return &node->base;
}

// Parse `let name = expr` into a definition node.
static Ast *
let(Parser *parser)
{
	AST_CREATE(AstDefinition, def, parser->arena, AST_DEFINITION);
	eat(parser, TK_LET);
	eat_identifier(parser, &def->name);
	eat(parser, TK_EQUAL);
	Ast *initial = expression(parser);
	def->value = initial;
	return &def->base;
}

// Parse `array(size, initializer)` into an array node.
static Ast *
array(Parser *parser)
{
	AST_CREATE(AstArray, node, parser->arena, AST_ARRAY);
	eat(parser, TK_ARRAY);
	eat(parser, TK_LPAREN);
	Ast *size = expression(parser);
	node->size = size;
	eat(parser, TK_COMMA);
	Ast *initializer = expression(parser);
	node->initializer = initializer;
	eat(parser, TK_RPAREN);
	return &node->base;
}

// Parse `object [extends expr] begin member; ... end` into an object node.
// Without an `extends` clause the parent is a null node. Every member must
// be a definition and member names must be unique; violations are reported
// at the `object` keyword without entering panic mode.
static Ast *
object(Parser *parser)
{
	AST_CREATE(AstObject, object, parser->arena, AST_OBJECT);
	eat(parser, TK_OBJECT);
	Token object_tok = prev_tok(parser);
	if (try_eat(parser, TK_EXTENDS)) {
		object->extends = expression(parser);
	} else {
		object->extends = create_null(parser);
	}
	eat(parser, TK_BEGIN);
	expression_list(parser, &object->members, &object->member_cnt, TK_SEMICOLON, TK_END);
	for (size_t i = 0; i < object->member_cnt; i++) {
		if (object->members[i]->kind != AST_DEFINITION) {
			parser_error(parser, object_tok, false, "Found object member that is not a definition");
			// Not an AstDefinition, so it has no name to compare.
			continue;
		}
		AstDefinition *member = (AstDefinition *) object->members[i];
		for (size_t j = 0; j < i; j++) {
			if (object->members[j]->kind != AST_DEFINITION) {
				continue;
			}
			AstDefinition *other = (AstDefinition *) object->members[j];
			if (str_eq(member->name, other->name)) {
				// "%.*s" takes an int precision and a char pointer.
				int len = (int) member->name.len;
				const char *str = (const char *) member->name.str;
				parser_error(parser, object_tok, false, "Found multiple times the object member '%.*s'", len, str);
			}
		}
	}
	return &object->base;
}

// Parse `if cond then expr [else expr]` into a conditional node.
// A missing `else` branch becomes a null node.
static Ast *
cond(Parser *parser)
{
	AST_CREATE(AstConditional, node, parser->arena, AST_CONDITIONAL);
	eat(parser, TK_IF);
	node->condition = expression(parser);
	eat(parser, TK_THEN);
	node->consequent = expression(parser);
	node->alternative = try_eat(parser, TK_ELSE)
		? expression(parser)
		: create_null(parser);
	return &node->base;
}

// Parse `while cond do body` into a loop node.
static Ast *
loop(Parser *parser)
{
	AST_CREATE(AstLoop, node, parser->arena, AST_LOOP);
	eat(parser, TK_WHILE);
	Ast *condition = expression(parser);
	node->condition = condition;
	eat(parser, TK_DO);
	Ast *body = expression(parser);
	node->body = body;
	return &node->base;
}

// Parse `print("format", args...)` into a print node.
// The number of '~' placeholders in the format (a character following a
// backslash is not counted) must equal the number of arguments; a mismatch
// is reported at the format string without entering panic mode.
static Ast *
print(Parser *parser)
{
	AST_CREATE(AstPrint, node, parser->arena, AST_PRINT);
	eat(parser, TK_PRINT);
	eat(parser, TK_LPAREN);
	eat(parser, TK_STRING);
	Token fmt_tok = prev_tok(parser);
	node->format = fmt_tok.str;

	// Count placeholders, stepping over escape sequences as a unit.
	size_t placeholders = 0;
	size_t i = 0;
	while (i < node->format.len) {
		u8 ch = node->format.str[i];
		if (ch == '\\') {
			i += 2;
			continue;
		}
		if (ch == '~') {
			placeholders += 1;
		}
		i += 1;
	}

	if (!try_eat(parser, TK_COMMA)) {
		eat(parser, TK_RPAREN);
	} else {
		expression_list(parser, &node->arguments, &node->argument_cnt, TK_COMMA, TK_RPAREN);
	}
	if (placeholders != node->argument_cnt) {
		parser_error(parser, fmt_tok, false, "Invalid number of print arguments: %zu expected, got %zu", placeholders, node->argument_cnt);
	}
	return &node->base;
}

// Parse a parenthesized expression; the parentheses leave no node behind.
static Ast *
paren(Parser *parser)
{
	eat(parser, TK_LPAREN);
	Ast *inner = expression(parser);
	eat(parser, TK_RPAREN);
	return inner;
}


// Parse `function [name](params) -> body`.
// An anonymous function yields the function node itself; a named one yields
// a definition node binding the name to the function.
static Ast *
function(Parser *parser)
{
	AST_CREATE(AstFunction, fn, parser->arena, AST_FUNCTION);
	eat(parser, TK_FUNCTION);
	Ast *result = &fn->base;
	bool is_named = tok_is_identifier(peek(parser));
	if (is_named) {
		AST_CREATE(AstDefinition, def, parser->arena, AST_DEFINITION);
		eat_identifier(parser, &def->name);
		def->value = &fn->base;
		result = &def->base;
	}
	eat(parser, TK_LPAREN);
	identifier_list(parser, &fn->parameters, &fn->parameter_cnt, TK_COMMA, TK_RPAREN);
	eat(parser, TK_RARROW);
	fn->body = expression(parser);
	return result;
}

// Infix handler of tokens with PREC_NONE. expression_bp stops before such
// tokens, so this handler is never expected to run.
static Ast *
stop(Parser *parser, Ast *left, int rbp)
{
	(void) rbp;
	(void) left;
	(void) parser;
	UNREACHABLE();
}

// Infix handler for tokens that can neither continue nor end an expression:
// reports the error (entering panic mode) and yields a null node.
static Ast *
lefterr(Parser *parser, Ast *left, int rbp)
{
	(void) rbp;
	(void) left;
	Token *offending = &parser->lookahead;
	parser_error(parser, *offending, true, "Invalid expression continuing/ending token %s", tok_repr[offending->kind]);
	// Pretend the token is end of input: its low binding power keeps
	// expression_bp from calling `lefterr` on the same token forever.
	offending->kind = TK_EOF;
	return create_null(parser);
}

// Infix handler for binary operators: `a op b` becomes a method call on `a`
// named after the operator text, with `b` as the single argument.
static Ast *
binop(Parser *parser, Ast *left, int rbp)
{
	AST_CREATE(AstMethodCall, node, parser->arena, AST_METHOD_CALL);
	Token op = discard(parser);
	node->object = left;
	node->name = op.str;
	node->argument_cnt = 1;
	node->arguments = arena_alloc(parser->arena, sizeof(*node->arguments));
	Ast *right = expression_bp(parser, rbp);
	node->arguments[0] = right;
	return &node->base;
}

// Postfix handler for `(`: a call with a comma separated argument list.
// Calling a field access (`obj.name(...)`) yields a method call on `obj`;
// any other callee yields a function call.
static Ast *
call(Parser *parser, Ast *left, int rbp)
{
	(void) rbp;
	eat(parser, TK_LPAREN);
	if (left->kind != AST_FIELD_ACCESS) {
		AST_CREATE(AstFunctionCall, fn_call, parser->arena, AST_FUNCTION_CALL);
		fn_call->function = left;
		expression_list(parser, &fn_call->arguments, &fn_call->argument_cnt, TK_COMMA, TK_RPAREN);
		return &fn_call->base;
	}

	AstFieldAccess *access = (AstFieldAccess *) left;
	AST_CREATE(AstMethodCall, method, parser->arena, AST_METHOD_CALL);
	// NOTE(review): the superseded field access node is retagged as well;
	// the reason is not visible in this file.
	left->kind = AST_METHOD_CALL;
	method->object = access->object;
	method->name = access->field;
	expression_list(parser, &method->arguments, &method->argument_cnt, TK_COMMA, TK_RPAREN);
	return &method->base;
}

// Postfix handler for `[`: parses `left[index]` into an index access node.
static Ast *
indexing(Parser *parser, Ast *left, int rbp)
{
	// The index is delimited by TK_RBRACKET, so the binding power is unused.
	(void) rbp;
	AST_CREATE(AstIndexAccess, node, parser->arena, AST_INDEX_ACCESS);
	node->object = left;
	eat(parser, TK_LBRACKET);
	Ast *index = expression(parser);
	node->index = index;
	eat(parser, TK_RBRACKET);
	return &node->base;
}

// Postfix handler for `.`: parses `left.name` into a field access node.
static Ast *
field(Parser *parser, Ast *left, int rbp)
{
	(void) rbp;
	AST_CREATE(AstFieldAccess, node, parser->arena, AST_FIELD_ACCESS);
	node->object = left;
	eat(parser, TK_DOT);
	eat_identifier(parser, &node->field);
	return &node->base;
}

// Infix handler for `<-`: turns the already parsed left-hand side into the
// matching assignment node (variable, field or index) and parses the value.
// For any other left-hand side an error is reported (without panic mode) and
// the left-hand side is returned unchanged; the value is not parsed then.
static Ast *
assign(Parser *parser, Ast *left, int rbp)
{
	eat(parser, TK_LARROW);

	if (left->kind == AST_VARIABLE_ACCESS) {
		AstVariableAccess *target = (AstVariableAccess *) left;
		AST_CREATE(AstVariableAssignment, node, parser->arena, AST_VARIABLE_ASSIGNMENT);
		node->name = target->name;
		node->value = expression_bp(parser, rbp);
		return &node->base;
	}

	if (left->kind == AST_FIELD_ACCESS) {
		AstFieldAccess *target = (AstFieldAccess *) left;
		AST_CREATE(AstFieldAssignment, node, parser->arena, AST_FIELD_ASSIGNMENT);
		node->object = target->object;
		node->field = target->field;
		node->value = expression_bp(parser, rbp);
		return &node->base;
	}

	if (left->kind == AST_INDEX_ACCESS) {
		AstIndexAccess *target = (AstIndexAccess *) left;
		AST_CREATE(AstIndexAssignment, node, parser->arena, AST_INDEX_ASSIGNMENT);
		node->object = target->object;
		node->index = target->index;
		node->value = expression_bp(parser, rbp);
		return &node->base;
	}

	parser_error(parser, parser->prev, false, "Invalid assignment left hand side, expected variable, index or field access");
	return left;
}

// Infix handler for a stray `=`: reports the likely mistake (without panic
// mode) and then continues as if `<-` had been written.
static Ast *
eqerr(Parser *parser, Ast *left, int rbp)
{
	const char *found = tok_repr[TK_EQUAL];
	const char *wanted = tok_repr[TK_LARROW];
	parser_error(parser, parser->lookahead, false, "Unexpected %s, did you mean to use %s for assignment?", found, wanted);
	// Rewrite the lookahead so that `assign` can consume it.
	parser->lookahead.kind = TK_LARROW;
	return assign(parser, left, rbp);
}


typedef struct {
	Ast *(*nud)(Parser *);
} NullInfo;

static NullInfo null_info[] = {
Michal Vlasák's avatar
Michal Vlasák committed
	#define TOK_NULL(tok, str, nud, led, lbp, rbp) { nud },
	TOKENS(TOK_NULL, TOK_NULL, TOK_NULL)
	#undef TOK_STR
};

// Decide whether error recovery may stop discarding tokens: either the
// previously consumed token is end of input (nothing left to synchronize),
// or it is an expression separator and the lookahead can start an expression.
static bool
at_synchronization_point(Parser *parser)
{
	TokenKind before = parser->prev.kind;
	if (before == TK_EOF) {
		return true;
	}
	if (before != TK_SEMICOLON) {
		return false;
	}
	return null_info[parser->lookahead.kind].nud != nullerr;
}

typedef struct {
	Ast *(*led)(Parser *, Ast *left, int rbp);
	int lbp;
	int rbp;
} LeftInfo;

static LeftInfo left_info[] = {
Michal Vlasák's avatar
Michal Vlasák committed
	#define TOK_LEFT(tok, str, nud, led, prec, assoc) { led, PREC_##prec, PREC_##prec + (ASSOC_##assoc == ASSOC_LEFT) },
	TOKENS(TOK_LEFT, TOK_LEFT, TOK_LEFT)
	#undef TOK_STR
};

// Core of the operator precedence parser: parse an expression whose
// operators all bind at least as tightly as `bp`.
// The lookahead's prefix handler produces the initial operand; then infix and
// postfix handlers keep extending it until a token with a left binding power
// below `bp` is met.
static Ast *
expression_bp(Parser *parser, int bp)
{
	Ast *left = null_info[peek(parser)].nud(parser);

	for (;;) {
		const LeftInfo *info = &left_info[peek(parser)];
		if (info->lbp < bp) {
			return left;
		}
		left = info->led(parser, left, info->rbp);
	}
}

// Parse a whole program: expressions separated by `;` up to end of input.
// When parsing stops in panic mode, tokens are discarded up to the next
// synchronization point, panic mode is cleared and parsing is restarted;
// each restart replaces the list collected so far. An empty program is
// turned into a null node.
static Ast *
top(Parser *parser)
{
	AST_CREATE(AstTop, program, parser->arena, AST_TOP);
	expression_list(parser, &program->expressions, &program->expression_cnt, TK_SEMICOLON, TK_EOF);
	while (parser->panic_mode) {
		// Always drop at least one token to guarantee progress.
		do {
			discard(parser);
		} while (!at_synchronization_point(parser));
		parser->panic_mode = false;
		expression_list(parser, &program->expressions, &program->expression_cnt, TK_SEMICOLON, TK_EOF);
	}
	if (program->expression_cnt == 0) {
		// empty program => null
		program->base.kind = AST_NULL;
	}
	return &program->base;
}

// Default error callback: `user_data` points to the Str holding the source.
// Prints "[line:column]: parser error: <message>" to stderr (both 1-based),
// followed by the offending source line and a caret under the error position.
void
parser_error_cb(void *user_data, const u8 *err_pos, const char *msg, va_list ap)
{
	Str *src = user_data;

	// Walk up to the error position, tracking the start of its line.
	const u8 *cursor = src->str;
	const u8 *line_start = cursor;
	size_t newlines = 0;
	while (cursor < err_pos) {
		if (*cursor == '\n') {
			newlines++;
			line_start = cursor + 1;
		}
		cursor++;
	}
	size_t col = cursor - line_start;

	// Find the end of the line containing the error.
	const u8 *source_end = src->str + src->len;
	const u8 *line_end = cursor;
	while (line_end < source_end && *line_end != '\n') {
		line_end++;
	}

	fprintf(stderr, "[%zu:%zu]: parser error: ", newlines + 1, col + 1);
	vfprintf(stderr, msg, ap);
	fprintf(stderr, "\n");
	fprintf(stderr, "  %.*s\n", (int) (line_end - line_start), line_start);
	// Right-align the caret in a field as wide as the column number.
	fprintf(stderr, "  %*s\n", (int) (cursor - line_start + 1), "^");
}

// Parse `source` into an AST allocated from `arena`, using `scratch` for
// temporary lists. Errors are delivered to `error_callback` together with
// `user_data`. Returns the root node, or NULL if any error was reported.
Ast *
parse(Arena *arena, GArena *scratch, Str source, void (*error_callback)(void *user_data, const u8 *err_pos, const char *msg, va_list ap), void *user_data)
{
	Parser parser = {0};
	parser.arena = arena;
	parser.scratch = scratch;
	parser.user_data = user_data;
	parser.error_callback = error_callback;
	parser.lexer = lex_create(source);
	parser.had_error = false;
	parser.panic_mode = false;

	// Prime the lookahead with the first real token.
	discard(&parser);

	Ast *root = top(&parser);
	return parser.had_error ? NULL : root;
}

// Convenience wrapper around parse(): uses a temporary scratch arena and the
// default stderr error callback. Returns the root node or NULL on error.
Ast *
parse_src(Arena *arena, Str source)
{
	GArena scratch;
	garena_init(&scratch);
	Ast *root = parse(arena, &scratch, source, parser_error_cb, &source);
	garena_destroy(&scratch);
	return root;
}