From fb03cbe8d57a724d8cddbfdb018509af7b37617a Mon Sep 17 00:00:00 2001
From: Peter Matta <mattapet@fit.cvut.cz>
Date: Mon, 23 Apr 2018 13:11:28 +0200
Subject: [PATCH] Parse constants in blocks [Fix]

---
 include/dusk/AST/ASTVisitor.h |  4 ++--
 include/dusk/AST/Expr.h       |  6 +++---
 include/dusk/AST/Pattern.h    | 26 +------------------------
 include/dusk/AST/Stmt.h       | 23 +++++++++++++++++++++-
 include/dusk/Parse/Parser.h   |  4 ++--
 lib/AST/ASTPrinter.cpp        | 36 ++++++++++++++++-------------------
 lib/AST/ASTWalker.cpp         | 19 +++++++++---------
 lib/AST/Expr.cpp              |  2 +-
 lib/AST/Pattern.cpp           |  8 --------
 lib/AST/Stmt.cpp              |  6 ++++++
 lib/Parser/ParseExpr.cpp      |  2 +-
 lib/Parser/ParsePattern.cpp   | 17 -----------------
 lib/Parser/ParseStmt.cpp      | 20 +++++++++++++++++++
 13 files changed, 84 insertions(+), 89 deletions(-)

diff --git a/include/dusk/AST/ASTVisitor.h b/include/dusk/AST/ASTVisitor.h
index 04d86f3..8bf3ba0 100644
--- a/include/dusk/AST/ASTVisitor.h
+++ b/include/dusk/AST/ASTVisitor.h
@@ -98,6 +98,8 @@ public:
       return getDerived().visit(static_cast<ReturnStmt *>(S));
     case StmtKind::Range:
       return getDerived().visit(static_cast<RangeStmt *>(S));
+    case StmtKind::Subscript:
+      return getDerived().visit(static_cast<SubscriptStmt *>(S));
     case StmtKind::Block:
       return getDerived().visit(static_cast<BlockStmt *>(S));
     case StmtKind::For:
@@ -118,8 +120,6 @@ public:
       return getDerived().visit(static_cast<ExprPattern *>(P));
     case PatternKind::Variable:
       return getDerived().visit(static_cast<VarPattern *>(P));
-    case PatternKind::Subscript:
-      return getDerived().visit(static_cast<SubscriptPattern *>(P));
     }
   }
 };
diff --git a/include/dusk/AST/Expr.h b/include/dusk/AST/Expr.h
index 3988808..ff18659 100644
--- a/include/dusk/AST/Expr.h
+++ b/include/dusk/AST/Expr.h
@@ -165,13 +165,13 @@ class SubscriptExpr : public Expr {
   Expr *Base;
 
   /// Subscription pattern
-  Pattern *Subscript;
+  Stmt *Subscript;
 
 public:
-  SubscriptExpr(Expr *B, Pattern *S);
+  SubscriptExpr(Expr *B, Stmt *S);
 
   Expr *getBase() { return Base; }
-  Pattern *getSubscript() { return Subscript; }
+  Stmt *getSubscript() { return Subscript; }
 
   virtual SMRange getSourceRange() const override;
 };
diff --git a/include/dusk/AST/Pattern.h b/include/dusk/AST/Pattern.h
index 1d90571..c0ebc47 100644
--- a/include/dusk/AST/Pattern.h
+++ b/include/dusk/AST/Pattern.h
@@ -25,7 +25,7 @@ class Stmt;
 class ParamDecl;
 
 /// Pattern description.
-enum struct PatternKind { Expr, Variable, Subscript };
+enum struct PatternKind { Expr, Variable };
 
 class Pattern : public ASTNode {
   /// Pattern type.
@@ -86,30 +86,6 @@ public:
   virtual SMRange getSourceRange() const override;
 };
 
-/// Subscript pattern
-///
-/// Pattern used in array declaration and in accessing array elements.
-class SubscriptPattern : public Pattern {
-  /// Subscript value
-  Expr *Value;
-
-  /// Location of left bracket
-  SMLoc LBracet;
-
-  /// Location of right bracket
-  SMLoc RBracet;
-
-public:
-  SubscriptPattern(Expr *V, SMLoc L, SMLoc R);
-
-  Expr *getValue() const { return Value; }
-  SMLoc getLBracket() const { return LBracet; }
-  SMLoc getRBracket() const { return RBracet; }
-
-  virtual size_t count() const override;
-  virtual SMRange getSourceRange() const override;
-};
-
 } // namespace dusk
 
 #endif /* DUSK_PATTERN_H */
diff --git a/include/dusk/AST/Stmt.h b/include/dusk/AST/Stmt.h
index 6ce554c..a811718 100644
--- a/include/dusk/AST/Stmt.h
+++ b/include/dusk/AST/Stmt.h
@@ -26,7 +26,7 @@ class IdentifierExpr;
 class ASTWalker;
 
 /// Describes statement type.
-enum struct StmtKind { Break, Return, Range, Block, Func, For, While, If };
+enum struct StmtKind { Break, Return, Range, Block, Func, For, While, If, Subscript };
 
 class Stmt : public ASTNode {
   /// Statement type
@@ -65,6 +65,27 @@ public:
   virtual SMRange getSourceRange() const override;
 };
 
+/// Subscript statement.
+class SubscriptStmt: public Stmt {
+  /// Subcript value
+  Expr *Value;
+  
+  /// Location of left bracket
+  SMLoc LBracket;
+  
+  /// Location of right bracket
+  SMLoc RBracket;
+  
+public:
+  SubscriptStmt(Expr *V, SMLoc L, SMLoc R);
+  
+  Expr *getValue() const { return Value; }
+  SMLoc getLBracket() const { return LBracket; }
+  SMLoc getRBracket() const { return RBracket; }
+  
+  virtual SMRange getSourceRange() const override;
+};
+  
 /// Represents a range.
 class RangeStmt : public Stmt {
   /// Start of the range
diff --git a/include/dusk/Parse/Parser.h b/include/dusk/Parse/Parser.h
index f9ef380..faea278 100644
--- a/include/dusk/Parse/Parser.h
+++ b/include/dusk/Parse/Parser.h
@@ -142,6 +142,8 @@ private:
 
   Stmt *parseBreakStmt();
   Stmt *parseReturnStmt();
+  
+  Stmt *parseSubscriptStmt();
 
   Stmt *parseFuncStmt();
 
@@ -163,8 +165,6 @@ private:
   SmallVector<Decl *, 128> parseVarPatternBody();
   Decl *parseVarPatternItem();
 
-  Pattern *parseSubscriptPattern();
-
   /// Creates and adds a new instance of \c ASTNode to the parser result
   /// and returns a pointer to it.
   template <typename Node, typename... Args> Node *make(Args &&... args) {
diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp
index 0b89050..625ab3f 100644
--- a/lib/AST/ASTPrinter.cpp
+++ b/lib/AST/ASTPrinter.cpp
@@ -149,6 +149,13 @@ public:
     return true;
   }
 
+  bool visit(SubscriptStmt *S) {
+    Printer.printStmtPre(S);
+    super::visit(S->getValue());
+    Printer.printStmtPost(S);
+    return true;
+  }
+
   bool visit(RangeStmt *S) {
     Printer.printStmtPre(S);
     super::visit(S->getStart());
@@ -252,13 +259,6 @@ public:
     Printer.printPatternPost(P);
     return true;
   }
-
-  bool visit(SubscriptPattern *P) {
-    Printer.printPatternPre(P);
-    super::visit(P->getValue());
-    Printer.printPatternPost(P);
-    return true;
-  }
 };
 
 /// Implementation of an \c ASTPrinter, which is used to pretty print the AST.
@@ -312,6 +312,9 @@ public:
       if (!isAtStartOfLine())
         printNewline();
       return;
+    case StmtKind::Subscript:
+      *this << "[";
+      break;
     default:
       return;
     }
@@ -319,13 +322,16 @@ public:
 
   virtual void printStmtPost(Stmt *S) override {
     switch (S->getKind()) {
+    case StmtKind::Block:
+      --(*this);
+      *this << tok::r_brace;
+      break;
     case StmtKind::Break:
     case StmtKind::Return:
       *this << ";";
       break;
-    case StmtKind::Block:
-      --(*this);
-      *this << tok::r_brace;
+    case StmtKind::Subscript:
+      *this << "]";
       break;
     default:
       return;
@@ -338,11 +344,6 @@ public:
     case PatternKind::Expr:
       *this << "(";
       break;
-    case PatternKind::Subscript:
-      *this << "[";
-      break;
-    default:
-      break;
     }
   }
 
@@ -352,11 +353,6 @@ public:
     case PatternKind::Expr:
       *this << ")";
       break;
-    case PatternKind::Subscript:
-      *this << "]";
-      break;
-    default:
-      break;
     }
   }
 };
diff --git a/lib/AST/ASTWalker.cpp b/lib/AST/ASTWalker.cpp
index d406656..2cd819d 100644
--- a/lib/AST/ASTWalker.cpp
+++ b/lib/AST/ASTWalker.cpp
@@ -208,6 +208,16 @@ public:
       return false;
     return Walker.postWalk(S);
   }
+  
+  bool visit(SubscriptStmt *S) {
+    // Skip subtree
+    if (!Walker.preWalk(S))
+      return true;
+    
+    if (!super::visit(S->getValue()))
+      return false;
+    return Walker.postWalk(S);
+  }
 
   bool visit(BlockStmt *S) {
     // Skip subtree
@@ -293,15 +303,6 @@ public:
         return false;
     return Walker.postWalk(P);
   }
-
-  bool visit(SubscriptPattern *P) {
-    // Skip subtree
-    if (!Walker.preWalk(P))
-      return true;
-
-    super::visit(P->getValue());
-    return Walker.postWalk(P);
-  }
 };
 
 } // End of anonymous namespace
diff --git a/lib/AST/Expr.cpp b/lib/AST/Expr.cpp
index 5f37462..996dae6 100644
--- a/lib/AST/Expr.cpp
+++ b/lib/AST/Expr.cpp
@@ -76,7 +76,7 @@ SMRange CallExpr::getSourceRange() const {
 
 // MARK: - Subscript expression
 
-SubscriptExpr::SubscriptExpr(Expr *B, Pattern *S)
+SubscriptExpr::SubscriptExpr(Expr *B, Stmt *S)
     : Expr(ExprKind::Subscript), Base(B), Subscript(S) {}
 
 SMRange SubscriptExpr::getSourceRange() const {
diff --git a/lib/AST/Pattern.cpp b/lib/AST/Pattern.cpp
index 9357e40..a17aa36 100644
--- a/lib/AST/Pattern.cpp
+++ b/lib/AST/Pattern.cpp
@@ -33,11 +33,3 @@ VarPattern::VarPattern(SmallVector<Decl *, 128> &&V, SMLoc L, SMLoc R)
 
 SMRange VarPattern::getSourceRange() const { return {LPar, RPar}; }
 size_t VarPattern::count() const { return Vars.size(); }
-
-// MARK: - Subscript pattern
-
-SubscriptPattern::SubscriptPattern(Expr *V, SMLoc L, SMLoc R)
-    : Pattern(PatternKind::Subscript), Value(V), LBracet(L), RBracet(R) {}
-
-SMRange SubscriptPattern::getSourceRange() const { return {LBracet, RBracet}; }
-size_t SubscriptPattern::count() const { return 1; }
diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp
index 48c69f6..522922d 100644
--- a/lib/AST/Stmt.cpp
+++ b/lib/AST/Stmt.cpp
@@ -40,6 +40,12 @@ SMRange RangeStmt::getSourceRange() const {
   return {Start->getLocStart(), End->getLocEnd()};
 }
 
+// MARK: - Subcsript statement
+SubscriptStmt::SubscriptStmt(Expr *V, SMLoc L, SMLoc R)
+    : Stmt(StmtKind::Subscript), Value(V), LBracket(L), RBracket(R) {}
+
+SMRange SubscriptStmt::getSourceRange() const { return {LBracket, RBracket}; }
+
 // MARK: - Block statement
 
 BlockStmt::BlockStmt(SMLoc S, SMLoc E, std::vector<ASTNode *> &&N)
diff --git a/lib/Parser/ParseExpr.cpp b/lib/Parser/ParseExpr.cpp
index 84c547d..270c560 100644
--- a/lib/Parser/ParseExpr.cpp
+++ b/lib/Parser/ParseExpr.cpp
@@ -260,7 +260,7 @@ Expr *Parser::parseCallExpr(Expr *Dest) {
 Expr *Parser::parseSubscriptExpr(Expr *Dest) {
   // Validate `[`
   assert(Tok.is(tok::l_bracket) && "Invalid parse method.");
-  return make<SubscriptExpr>(Dest, parseSubscriptPattern());
+  return make<SubscriptExpr>(Dest, parseSubscriptStmt());
 }
 
 /// PrimaryExpr ::= '(' Expr ')'
diff --git a/lib/Parser/ParsePattern.cpp b/lib/Parser/ParsePattern.cpp
index 6cf4a71..376f944 100644
--- a/lib/Parser/ParsePattern.cpp
+++ b/lib/Parser/ParsePattern.cpp
@@ -148,20 +148,3 @@ Decl *Parser::parseVarPatternItem() {
   }
 }
 
-/// SubscriptionPattern ::=
-///     [ Expr ]
-Pattern *Parser::parseSubscriptPattern() {
-  // Validate `[` start.
-  assert(Tok.is(tok::l_bracket) && "Invalid parse method.");
-
-  auto L = consumeToken();
-  auto V = parseExpr();
-  if (!consumeIf(tok::r_bracket)) {
-    diagnose(Tok.getLoc(), diag::DiagID::expected_r_bracket)
-      .fixItAfter("]", PreviousLoc);
-    return nullptr;
-  }
-
-  return make<SubscriptPattern>(V, L, PreviousLoc);
-}
-
diff --git a/lib/Parser/ParseStmt.cpp b/lib/Parser/ParseStmt.cpp
index d0720dd..79836c8 100644
--- a/lib/Parser/ParseStmt.cpp
+++ b/lib/Parser/ParseStmt.cpp
@@ -70,6 +70,23 @@ Stmt *Parser::parseReturnStmt() {
   return make<ReturnStmt>(RL, E);
 }
 
+/// SubscriptionPattern ::=
+///     [ Expr ]
+Stmt *Parser::parseSubscriptStmt() {
+  // Validate `[` start.
+  assert(Tok.is(tok::l_bracket) && "Invalid parse method.");
+  
+  auto L = consumeToken();
+  auto V = parseExpr();
+  if (!consumeIf(tok::r_bracket)) {
+    diagnose(Tok.getLoc(), diag::DiagID::expected_r_bracket)
+    .fixItAfter("]", PreviousLoc);
+    return nullptr;
+  }
+  
+  return make<SubscriptStmt>(V, L, PreviousLoc);
+}
+
 /// Block ::=
 ///     '{' BlockBody '}'
 Stmt *Parser::parseBlock() {
@@ -104,6 +121,9 @@ ASTNode *Parser::parseBlockBody() {
 
     case tok::kwVar:
       return parseVarDecl();
+        
+    case tok::kwConst:
+      return parseConstDecl();
 
     case tok::kwBreak:
       return parseBreakStmt();
-- 
GitLab