From b1d12fde20067fe325fc9544da325fea912ebdcb Mon Sep 17 00:00:00 2001
From: Peter Matta <mattapet@fit.cvut.cz>
Date: Wed, 18 Apr 2018 21:04:23 +0200
Subject: [PATCH] Added support for return/break keywords

---
 include/dusk/AST/ASTVisitor.h        |  4 ++
 include/dusk/AST/Stmt.h              | 29 ++++++++++++
 include/dusk/Parse/Parser.h          |  3 ++
 include/dusk/Parse/Token.h           |  5 ++-
 include/dusk/Parse/TokenDefinition.h |  6 ++-
 lib/AST/ASTPrinter.cpp               | 22 +++++++++
 lib/AST/ASTWalker.cpp                | 17 +++++++
 lib/AST/Stmt.cpp                     | 23 ++++++++++
 lib/Parser/Lexer.cpp                 |  3 +-
 lib/Parser/ParseStmt.cpp             | 67 +++++++++++++++++++++-------
 10 files changed, 157 insertions(+), 22 deletions(-)

diff --git a/include/dusk/AST/ASTVisitor.h b/include/dusk/AST/ASTVisitor.h
index 39c81d2..4731d9b 100644
--- a/include/dusk/AST/ASTVisitor.h
+++ b/include/dusk/AST/ASTVisitor.h
@@ -92,6 +92,10 @@ public:
     /// Visit a concrete statement node.
     bool visit(Stmt *S) {
         switch (S->getKind()) {
+        case StmtKind::Break:
+            return getDerived().visit(static_cast<BreakStmt *>(S));
+        case StmtKind::Return:
+            return getDerived().visit(static_cast<ReturnStmt *>(S));
         case StmtKind::Range:
             return getDerived().visit(static_cast<RangeStmt *>(S));
         case StmtKind::Block:
diff --git a/include/dusk/AST/Stmt.h b/include/dusk/AST/Stmt.h
index 53e9772..63c274e 100644
--- a/include/dusk/AST/Stmt.h
+++ b/include/dusk/AST/Stmt.h
@@ -28,6 +28,8 @@ class ASTWalker;
 
 /// Describes statement type.
 enum struct StmtKind {
+    Break,
+    Return,
     Range,
     Block,
     Func,
@@ -47,6 +49,32 @@ public:
     StmtKind getKind() const { return Kind; }
 };
 
+/// Represents a `break` statement in a loop.
+class BreakStmt: public Stmt {
+    /// Range of \c break keyword
+    llvm::SMRange BreakLoc;
+    
+public:
+    BreakStmt(llvm::SMRange BR);
+    
+    virtual llvm::SMRange getSourceRange() const override;
+};
+   
+/// Represents a `return` statement.
+class ReturnStmt: public Stmt {
+    /// Location of \c return keyword
+    llvm::SMLoc RetLoc;
+    
+    /// Value that is to be returned.
+    Expr *Value;
+    
+public:
+    ReturnStmt(llvm::SMLoc RL, Expr *V);
+    
+    Expr *getValue() const { return Value; }
+    virtual llvm::SMRange getSourceRange() const override;
+};
+    
 /// Represents a range.
 class RangeStmt: public Stmt {
     /// Start of the range
@@ -71,6 +99,7 @@ public:
     virtual llvm::SMRange getSourceRange() const override;
 };
 
+
 /// Represents an arbitrary block of code.
 class BlockStmt: public Stmt {
     /// Location of block's opening \c {
diff --git a/include/dusk/Parse/Parser.h b/include/dusk/Parse/Parser.h
index 5e21a52..61ee964 100644
--- a/include/dusk/Parse/Parser.h
+++ b/include/dusk/Parse/Parser.h
@@ -120,6 +120,9 @@ private:
     
     Expr *parseExprStmt();
     
+    BreakStmt *parseBreakStmt();
+    ReturnStmt *parseReturnStmt();
+    
     FuncStmt *parseFuncStmt();
     
     ForStmt *parseForStmt();
diff --git a/include/dusk/Parse/Token.h b/include/dusk/Parse/Token.h
index 5edfc33..4056e28 100644
--- a/include/dusk/Parse/Token.h
+++ b/include/dusk/Parse/Token.h
@@ -105,18 +105,19 @@ public:
         return is(tok::identifier);
     }
 
-    /// \return \c true, if token is a keyword, \c false otherwise.
+    /// Returns \c true, if token is a keyword, \c false otherwise.
     bool isKeyword() const {
         switch (Kind) {
         case tok::kwVar:
         case tok::kwConst:
+        case tok::kwBreak:
+        case tok::kwReturn:
         case tok::kwIf:
         case tok::kwElse:
         case tok::kwWhile:
         case tok::kwFor:
         case tok::kwIn:
         case tok::kwFunc:
-        case tok::kwReturn:
         case tok::kwWriteln:
         case tok::kwReadln:
             return true;
diff --git a/include/dusk/Parse/TokenDefinition.h b/include/dusk/Parse/TokenDefinition.h
index 4c25538..7d8aadf 100644
--- a/include/dusk/Parse/TokenDefinition.h
+++ b/include/dusk/Parse/TokenDefinition.h
@@ -19,13 +19,14 @@ enum struct tok {
 // Keywords
     kwVar,
     kwConst,
+    kwBreak,
+    kwReturn,
     kwIf,
     kwElse,
     kwWhile,
     kwFor,
     kwIn,
     kwFunc,
-    kwReturn,
     kwWriteln,
     kwReadln,
 
@@ -89,13 +90,14 @@ namespace llvm {
         // Keywords
         case dusk::tok::kwVar:        return OS << "var";
         case dusk::tok::kwConst:      return OS << "const";
+        case dusk::tok::kwBreak:      return OS << "break";
+        case dusk::tok::kwReturn:     return OS << "return";
         case dusk::tok::kwIf:         return OS << "if";
         case dusk::tok::kwElse:       return OS << "else";
         case dusk::tok::kwWhile:      return OS << "while";
         case dusk::tok::kwFor:        return OS << "for";
         case dusk::tok::kwIn:         return OS << "in";
         case dusk::tok::kwFunc:       return OS << "func";
-        case dusk::tok::kwReturn:     return OS << "return";
         case dusk::tok::kwWriteln:    return OS << "writeln";
         case dusk::tok::kwReadln:     return OS << "readln";
                 
diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp
index 0d4a12c..88615b0 100644
--- a/lib/AST/ASTPrinter.cpp
+++ b/lib/AST/ASTPrinter.cpp
@@ -148,6 +148,24 @@ public:
     }
 
     // MARK: - Statement nodes
+    
+    bool visit(BreakStmt *S) {
+        Printer.printStmtPre(S);
+        Printer << tok::kwBreak;
+        Printer.printStmtPost(S);
+        return true;
+    }
+    
+    bool visit(ReturnStmt *S) {
+        Printer.printStmtPre(S);
+        
+        Printer << tok::kwReturn << " ";
+        super::visit(S->getValue());
+
+        Printer.printStmtPost(S);
+        return true;
+    }
+    
     bool visit(RangeStmt *S) {
         Printer.printStmtPre(S);
         super::visit(S->getStart());
@@ -285,6 +303,10 @@ public:
 
     virtual void printStmtPost(Stmt *S) override {
         switch (S->getKind()) {
+        case StmtKind::Break:
+        case StmtKind::Return:
+            *this << ";";
+            break;
         case StmtKind::Block:
             --(*this);
             *this << tok::r_brace;
diff --git a/lib/AST/ASTWalker.cpp b/lib/AST/ASTWalker.cpp
index d4e7d72..8e8032e 100644
--- a/lib/AST/ASTWalker.cpp
+++ b/lib/AST/ASTWalker.cpp
@@ -184,6 +184,23 @@ public:
 
     // MARK: - Statement nodes
 
+    bool visit(BreakStmt *S) {
+        // Skip subtree
+        if (!Walker.preWalk(S))
+            return true;
+        return Walker.postWalk(S);
+    }
+    
+    bool visit(ReturnStmt *S) {
+        // Skip subtree
+        if (!Walker.preWalk(S))
+            return true;
+        
+        if (!super::visit(S->getValue()))
+            return false;
+        return Walker.postWalk(S);
+    }
+    
     bool visit(RangeStmt *S) {
         // Skip subtree
         if (!Walker.preWalk(S))
diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp
index 4a6b61f..591679a 100644
--- a/lib/AST/Stmt.cpp
+++ b/lib/AST/Stmt.cpp
@@ -14,6 +14,29 @@
 
 using namespace dusk;
 
+// MARK: - Break statement
+
+BreakStmt::BreakStmt(llvm::SMRange BL)
+: Stmt(StmtKind::Break), BreakLoc(BL)
+{}
+
+llvm::SMRange BreakStmt::getSourceRange() const {
+    return BreakLoc;
+}
+
+// MARK: - Return statement
+
+ReturnStmt::ReturnStmt(llvm::SMLoc RL, Expr *V)
+: Stmt(StmtKind::Return), RetLoc(RL), Value(V)
+{
+    assert(V && "Invalid `return` statement.");
+}
+
+llvm::SMRange ReturnStmt::getSourceRange() const {
+    return { RetLoc, Value->getLocEnd() };
+}
+
+
 // MARK: - Range statement
 
 RangeStmt::RangeStmt(Expr *S, Expr *E, Token O)
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 657dbaa..14ef7e8 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -266,13 +266,14 @@ tok Lexer::kindOfIdentifier(llvm::StringRef Str) {
     return llvm::StringSwitch<tok>(Str)
     .Case("var", tok::kwVar)
     .Case("const", tok::kwConst)
+    .Case("break", tok::kwBreak)
+    .Case("return", tok::kwReturn)
     .Case("if", tok::kwIf)
     .Case("else", tok::kwElse)
     .Case("while", tok::kwWhile)
     .Case("for", tok::kwFor)
     .Case("in", tok::kwIn)
     .Case("func", tok::kwFunc)
-    .Case("return", tok::kwReturn)
     .Case("writeln", tok::kwWriteln)
     .Case("readln", tok::kwReadln)
     .Default(tok::identifier);
diff --git a/lib/Parser/ParseStmt.cpp b/lib/Parser/ParseStmt.cpp
index 1f985a2..f95f75f 100644
--- a/lib/Parser/ParseStmt.cpp
+++ b/lib/Parser/ParseStmt.cpp
@@ -17,6 +17,51 @@ ASTNode *Parser::parseStatement() {
     return nullptr;
 }
 
+/// ExprStmt ::=
+///     Expr ';'
+Expr *Parser::parseExprStmt() {
+    Expr *E;
+    switch (Tok.getKind()) {
+        case tok::identifier:
+        case tok::number_literal:
+        case tok::l_paren:
+            E = parseExpr();
+            break;
+        default:
+            llvm_unreachable("Unexpected token.");
+    }
+    if (!consumeIf(tok::semicolon))
+        assert("Missing semicolon at the end of the line." && false);
+    return E;
+}
+
+/// BreakStmt ::=
+///     break ';'
+BreakStmt *Parser::parseBreakStmt() {
+    // Validate `break` keyword
+    assert(Tok.is(tok::kwBreak) && "Invalid parse method.");
+    auto T = Tok.getText();
+    auto S = consumeToken();
+    auto E = llvm::SMLoc::getFromPointer(T.data() + T.size());
+    
+    if (!consumeIf(tok::semicolon))
+        assert("Missing semicolon at the end of the line." && false);
+    return new BreakStmt({ S, E });
+}
+
+/// ReturnStmt ::=
+///     return Expr ';'
+ReturnStmt *Parser::parseReturnStmt() {
+    // Validate `return` keyword
+    assert(Tok.is(tok::kwReturn) && "Invalid parse method.");
+    auto RL = consumeToken();
+    auto E = parseExpr();
+    
+    if (!consumeIf(tok::semicolon))
+        assert("Missing semicolon at the end of the line." && false);
+    return new ReturnStmt(RL, E);
+}
+
 /// Block ::=
 ///     '{' BlockBody '}'
 BlockStmt *Parser::parseBlock() {
@@ -41,6 +86,11 @@ ASTNode *Parser::parseBlockBody() {
         
     case tok::kwVar:
         return parseVarDecl();
+            
+    case tok::kwBreak:
+        return parseBreakStmt();
+    case tok::kwReturn:
+        return parseReturnStmt();
     
     case tok::identifier:
     case tok::number_literal:
@@ -59,23 +109,6 @@ ASTNode *Parser::parseBlockBody() {
     }
 }
 
-Expr *Parser::parseExprStmt() {
-    Expr *E;
-    switch (Tok.getKind()) {
-        case tok::identifier:
-        case tok::number_literal:
-        case tok::l_paren:
-            E = parseExpr();
-            break;
-            
-        default:
-            llvm_unreachable("Unexpected token.");
-    }
-    if (!consumeIf(tok::semicolon))
-        assert("Missing semicolon at the end of the line" && false);
-    return E;
-}
-
 FuncStmt *Parser::parseFuncStmt() {
     // Validate `func` keyword
     assert(Tok.is(tok::kwFunc) && "Invalid parse method");
-- 
GitLab