Initial version for chapter 1 of the Toy tutorial

--

PiperOrigin-RevId: 241549247
Mehdi Amini 2019-04-02 10:02:07 -07:00 committed by Mehdi Amini
parent 7c1fc9e795
commit 38b71d6b84
15 changed files with 1593 additions and 1 deletion

View file

@@ -43,3 +43,7 @@ add_subdirectory(lib)
add_subdirectory(tools)
add_subdirectory(unittests)
add_subdirectory(test)
if( LLVM_INCLUDE_EXAMPLES )
add_subdirectory(examples)
endif()

View file

@@ -0,0 +1,2 @@
add_subdirectory(toy)

View file

@@ -0,0 +1,9 @@
add_custom_target(Toy)
set_target_properties(Toy PROPERTIES FOLDER Examples)
macro(add_toy_chapter name)
add_dependencies(Toy ${name})
add_llvm_example(${name} ${ARGN})
endmacro(add_toy_chapter name)
add_subdirectory(Ch1)

View file

@@ -0,0 +1,9 @@
set(LLVM_LINK_COMPONENTS
Support
)
add_toy_chapter(toyc-ch1
toyc.cpp
parser/AST.cpp
)
include_directories(include/)

View file

@@ -0,0 +1,256 @@
//===- AST.h - Node definition for the Toy AST ----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the AST for the Toy language. It is optimized for
// simplicity, not efficiency. The AST forms a tree structure where each node
// references its children using std::unique_ptr<>.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TUTORIAL_TOY_AST_H_
#define MLIR_TUTORIAL_TOY_AST_H_
#include "toy/Lexer.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include <vector>
namespace toy {
/// A variable type: an element type and a shape.
struct VarType {
enum { TY_FLOAT, TY_INT } elt_ty;
std::vector<int> shape;
};
/// Base class for all expression nodes.
class ExprAST {
public:
enum ExprASTKind {
Expr_VarDecl,
Expr_Return,
Expr_Num,
Expr_Literal,
Expr_Var,
Expr_BinOp,
Expr_Call,
Expr_Print, // builtin
Expr_If,
Expr_For,
};
ExprAST(ExprASTKind kind, Location location)
: kind(kind), location(location) {}
virtual ~ExprAST() = default;
ExprASTKind getKind() const { return kind; }
const Location &loc() { return location; }
private:
const ExprASTKind kind;
Location location;
};
/// A block-list of expressions.
using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
/// Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
double Val;
public:
NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
double getValue() { return Val; }
/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
};
/// Expression class for a literal value (a possibly nested array of numbers).
class LiteralExprAST : public ExprAST {
std::vector<std::unique_ptr<ExprAST>> values;
std::vector<int64_t> dims;
public:
LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
std::vector<int64_t> dims)
: ExprAST(Expr_Literal, loc), values(std::move(values)),
dims(std::move(dims)) {}
std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
std::vector<int64_t> &getDims() { return dims; }
/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
};
/// Expression class for referencing a variable, like "a".
class VariableExprAST : public ExprAST {
std::string name;
public:
VariableExprAST(Location loc, const std::string &name)
: ExprAST(Expr_Var, loc), name(name) {}
llvm::StringRef getName() { return name; }
/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
};
/// Expression class for defining a variable.
class VarDeclExprAST : public ExprAST {
std::string name;
VarType type;
std::unique_ptr<ExprAST> initVal;
public:
VarDeclExprAST(Location loc, const std::string &name, VarType type,
std::unique_ptr<ExprAST> initVal)
: ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
initVal(std::move(initVal)) {}
llvm::StringRef getName() { return name; }
ExprAST *getInitVal() { return initVal.get(); }
VarType &getType() { return type; }
/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
};
/// Expression class for a return operator.
class ReturnExprAST : public ExprAST {
llvm::Optional<std::unique_ptr<ExprAST>> expr;
public:
ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
: ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
llvm::Optional<ExprAST *> getExpr() {
if (expr.hasValue())
return expr->get();
return llvm::NoneType();
}
/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
};
/// Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
char Op;
std::unique_ptr<ExprAST> LHS, RHS;
public:
char getOp() { return Op; }
ExprAST *getLHS() { return LHS.get(); }
ExprAST *getRHS() { return RHS.get(); }
BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
std::unique_ptr<ExprAST> RHS)
: ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
RHS(std::move(RHS)) {}
/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
};
/// Expression class for function calls.
class CallExprAST : public ExprAST {
std::string Callee;
std::vector<std::unique_ptr<ExprAST>> Args;
public:
CallExprAST(Location loc, const std::string &Callee,
std::vector<std::unique_ptr<ExprAST>> Args)
: ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
llvm::StringRef getCallee() { return Callee; }
llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
};
/// Expression class for builtin print calls.
class PrintExprAST : public ExprAST {
std::unique_ptr<ExprAST> Arg;
public:
PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
: ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
ExprAST *getArg() { return Arg.get(); }
/// LLVM style RTTI
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
};
/// This class represents the "prototype" for a function, which captures its
/// name, and its argument names (thus implicitly the number of arguments the
/// function takes).
class PrototypeAST {
Location location;
std::string name;
std::vector<std::unique_ptr<VariableExprAST>> args;
public:
PrototypeAST(Location location, const std::string &name,
std::vector<std::unique_ptr<VariableExprAST>> args)
: location(location), name(name), args(std::move(args)) {}
const Location &loc() { return location; }
const std::string &getName() const { return name; }
const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
return args;
}
};
/// This class represents a function definition itself.
class FunctionAST {
std::unique_ptr<PrototypeAST> Proto;
std::unique_ptr<ExprASTList> Body;
public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
std::unique_ptr<ExprASTList> Body)
: Proto(std::move(Proto)), Body(std::move(Body)) {}
PrototypeAST *getProto() { return Proto.get(); }
ExprASTList *getBody() { return Body.get(); }
};
/// This class represents a list of functions to be processed together
class ModuleAST {
std::vector<FunctionAST> functions;
public:
ModuleAST(std::vector<FunctionAST> functions)
: functions(std::move(functions)) {}
auto begin() -> decltype(functions.begin()) { return functions.begin(); }
auto end() -> decltype(functions.end()) { return functions.end(); }
};
void dump(ModuleAST &);
} // namespace toy
#endif // MLIR_TUTORIAL_TOY_AST_H_

View file

@@ -0,0 +1,239 @@
//===- Lexer.h - Lexer for the Toy language -------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a simple Lexer for the Toy language.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TUTORIAL_TOY_LEXER_H_
#define MLIR_TUTORIAL_TOY_LEXER_H_
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <string>
namespace toy {
/// Structure defining a location in a file.
struct Location {
std::shared_ptr<std::string> file; ///< filename
int line; ///< line number.
int col; ///< column number.
};
// List of Token returned by the lexer.
enum Token : int {
tok_semicolon = ';',
tok_parenthese_open = '(',
tok_parenthese_close = ')',
tok_bracket_open = '{',
tok_bracket_close = '}',
tok_sbracket_open = '[',
tok_sbracket_close = ']',
tok_eof = -1,
// commands
tok_return = -2,
tok_var = -3,
tok_def = -4,
// primary
tok_identifier = -5,
tok_number = -6,
};
/// The Lexer is an abstract base class providing all the facilities that the
/// Parser expects. It goes through the stream one token at a time and keeps
/// track of the location in the file for debugging purposes.
/// It relies on a subclass to provide a `readNextLine()` method. The subclass
/// can proceed by reading the next line from the standard input or from a
/// memory mapped file.
class Lexer {
public:
/// Create a lexer for the given filename. The filename is kept only for
/// debugging purposes (attaching a location to a Token).
Lexer(std::string filename)
: lastLocation(
{std::make_shared<std::string>(std::move(filename)), 0, 0}) {}
virtual ~Lexer() = default;
/// Look at the current token in the stream.
Token getCurToken() { return curTok; }
/// Move to the next token in the stream and return it.
Token getNextToken() { return curTok = getTok(); }
/// Move to the next token in the stream, asserting on the current token
/// matching the expectation.
void consume(Token tok) {
assert(tok == curTok && "consume Token mismatch expectation");
getNextToken();
}
/// Return the current identifier (prereq: getCurToken() == tok_identifier)
llvm::StringRef getId() {
assert(curTok == tok_identifier);
return IdentifierStr;
}
/// Return the current number (prereq: getCurToken() == tok_number)
double getValue() {
assert(curTok == tok_number);
return NumVal;
}
/// Return the location for the beginning of the current token.
Location getLastLocation() { return lastLocation; }
// Return the current line in the file.
int getLine() { return curLineNum; }
// Return the current column in the file.
int getCol() { return curCol; }
private:
/// Delegate to a derived class fetching the next line. Returns an empty
/// string to signal end of file (EOF). Lines are expected to always finish
/// with "\n"
virtual llvm::StringRef readNextLine() = 0;
/// Return the next character from the stream. This manages the buffer for the
/// current line and requests the next line buffer from the derived class as
/// needed.
int getNextChar() {
// The current line buffer should not be empty unless it is the end of file.
if (curLineBuffer.empty())
return EOF;
++curCol;
auto nextchar = curLineBuffer.front();
curLineBuffer = curLineBuffer.drop_front();
if (curLineBuffer.empty())
curLineBuffer = readNextLine();
if (nextchar == '\n') {
++curLineNum;
curCol = 0;
}
return nextchar;
}
/// Return the next token from the stream.
Token getTok() {
// Skip any whitespace.
while (isspace(LastChar))
LastChar = Token(getNextChar());
// Save the current location before reading the token characters.
lastLocation.line = curLineNum;
lastLocation.col = curCol;
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
IdentifierStr = (char)LastChar;
while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
IdentifierStr += (char)LastChar;
if (IdentifierStr == "return")
return tok_return;
if (IdentifierStr == "def")
return tok_def;
if (IdentifierStr == "var")
return tok_var;
return tok_identifier;
}
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
std::string NumStr;
do {
NumStr += LastChar;
LastChar = Token(getNextChar());
} while (isdigit(LastChar) || LastChar == '.');
NumVal = strtod(NumStr.c_str(), nullptr);
return tok_number;
}
if (LastChar == '#') {
// Comment until end of line.
do
LastChar = Token(getNextChar());
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
if (LastChar != EOF)
return getTok();
}
// Check for end of file. Don't eat the EOF.
if (LastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
Token ThisChar = Token(LastChar);
LastChar = Token(getNextChar());
return ThisChar;
}
/// The last token read from the input.
Token curTok = tok_eof;
/// Location for `curTok`.
Location lastLocation;
/// If the current Token is an identifier, this string contains the value.
std::string IdentifierStr;
/// If the current Token is a number, this contains the value.
double NumVal = 0;
/// The last value returned by getNextChar(). We need to keep it around as we
/// always need to read ahead one character to decide when to end a token and
/// we can't put it back in the stream after reading from it.
Token LastChar = Token(' ');
/// Keep track of the current line number in the input stream
int curLineNum = 0;
/// Keep track of the current column number in the input stream
int curCol = 0;
/// Buffer supplied by the derived class on calls to `readNextLine()`
llvm::StringRef curLineBuffer = "\n";
};
/// A lexer implementation operating on a buffer in memory.
class LexerBuffer final : public Lexer {
public:
LexerBuffer(const char *begin, const char *end, std::string filename)
: Lexer(std::move(filename)), current(begin), end(end) {}
private:
/// Provide one line at a time to the Lexer, return an empty string when
/// reaching the end of the buffer.
llvm::StringRef readNextLine() override {
auto *begin = current;
while (current <= end && *current && *current != '\n')
++current;
if (current <= end && *current)
++current;
llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
return result;
}
const char *current, *end;
};
} // namespace toy
#endif // MLIR_TUTORIAL_TOY_LEXER_H_

View file

@@ -0,0 +1,494 @@
//===- Parser.h - Toy Language Parser -------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the parser for the Toy language. It processes the Token
// provided by the Lexer and returns an AST.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TUTORIAL_TOY_PARSER_H
#define MLIR_TUTORIAL_TOY_PARSER_H
#include "toy/AST.h"
#include "toy/Lexer.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <map>
#include <utility>
#include <vector>
namespace toy {
/// This is a simple recursive parser for the Toy language. It produces a well
/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
/// or symbol resolution is performed. For example, variables are referenced by
/// name, so the code can reference an undeclared variable and parsing will
/// still succeed.
class Parser {
public:
/// Create a Parser for the supplied lexer.
Parser(Lexer &lexer) : lexer(lexer) {}
/// Parse a full Module. A module is a list of function definitions.
std::unique_ptr<ModuleAST> ParseModule() {
lexer.getNextToken(); // prime the lexer
// Parse functions one at a time and accumulate in this vector.
std::vector<FunctionAST> functions;
while (auto F = ParseDefinition()) {
functions.push_back(std::move(*F));
if (lexer.getCurToken() == tok_eof)
break;
}
// If we didn't reach EOF, there was an error during parsing
if (lexer.getCurToken() != tok_eof)
return parseError<ModuleAST>("nothing", "at end of module");
return llvm::make_unique<ModuleAST>(std::move(functions));
}
private:
Lexer &lexer;
/// Parse a return statement.
/// return ::= return ; | return expr ;
std::unique_ptr<ReturnExprAST> ParseReturn() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_return);
// return takes an optional argument
llvm::Optional<std::unique_ptr<ExprAST>> expr;
if (lexer.getCurToken() != ';') {
expr = ParseExpression();
if (!expr)
return nullptr;
}
return llvm::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
}
/// Parse a literal number.
/// numberexpr ::= number
std::unique_ptr<ExprAST> ParseNumberExpr() {
auto loc = lexer.getLastLocation();
auto Result =
llvm::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
lexer.consume(tok_number);
return std::move(Result);
}
/// Parse a literal array expression.
/// tensorLiteral ::= [ literalList ] | number
/// literalList ::= tensorLiteral | tensorLiteral, literalList
std::unique_ptr<ExprAST> ParseTensorLiteralExpr() {
auto loc = lexer.getLastLocation();
lexer.consume(Token('['));
// Hold the list of values at this nesting level.
std::vector<std::unique_ptr<ExprAST>> values;
// Hold the dimensions for all the nesting inside this level.
std::vector<int64_t> dims;
do {
// We can have either another nested array or a number literal.
if (lexer.getCurToken() == '[') {
values.push_back(ParseTensorLiteralExpr());
if (!values.back())
return nullptr; // parse error in the nested array.
} else {
if (lexer.getCurToken() != tok_number)
return parseError<ExprAST>("<num> or [", "in literal expression");
values.push_back(ParseNumberExpr());
}
// End of this list on ']'
if (lexer.getCurToken() == ']')
break;
// Elements are separated by a comma.
if (lexer.getCurToken() != ',')
return parseError<ExprAST>("] or ,", "in literal expression");
lexer.getNextToken(); // eat ,
} while (true);
if (values.empty())
return parseError<ExprAST>("<something>", "to fill literal expression");
lexer.getNextToken(); // eat ]
/// Fill in the dimensions now. First the current nesting level:
dims.push_back(values.size());
/// If there is any nested array, process all of them and ensure that
/// dimensions are uniform.
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
return llvm::isa<LiteralExprAST>(expr.get());
})) {
auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
if (!firstLiteral)
return parseError<ExprAST>("uniform well-nested dimensions",
"inside literal expession");
// Append the nested dimensions to the current level
auto &firstDims = firstLiteral->getDims();
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
// Sanity check that shape is uniform across all elements of the list.
for (auto &expr : values) {
auto *exprLiteral = llvm::cast<LiteralExprAST>(expr.get());
if (!exprLiteral)
return parseError<ExprAST>("uniform well-nested dimensions",
"inside literal expession");
if (exprLiteral->getDims() != firstDims)
return parseError<ExprAST>("uniform well-nested dimensions",
"inside literal expession");
}
}
return llvm::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
std::move(dims));
}
/// parenexpr ::= '(' expression ')'
std::unique_ptr<ExprAST> ParseParenExpr() {
lexer.getNextToken(); // eat (.
auto V = ParseExpression();
if (!V)
return nullptr;
if (lexer.getCurToken() != ')')
return parseError<ExprAST>(")", "to close expression with parentheses");
lexer.consume(Token(')'));
return V;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression ')'
std::unique_ptr<ExprAST> ParseIdentifierExpr() {
std::string name = lexer.getId();
auto loc = lexer.getLastLocation();
lexer.getNextToken(); // eat identifier.
if (lexer.getCurToken() != '(') // Simple variable ref.
return llvm::make_unique<VariableExprAST>(std::move(loc), name);
// This is a function call.
lexer.consume(Token('('));
std::vector<std::unique_ptr<ExprAST>> Args;
if (lexer.getCurToken() != ')') {
while (true) {
if (auto Arg = ParseExpression())
Args.push_back(std::move(Arg));
else
return nullptr;
if (lexer.getCurToken() == ')')
break;
if (lexer.getCurToken() != ',')
return parseError<ExprAST>(", or )", "in argument list");
lexer.getNextToken();
}
}
lexer.consume(Token(')'));
// It can be a builtin call to print
if (name == "print") {
if (Args.size() != 1)
return parseError<ExprAST>("<single arg>", "as argument to print()");
return llvm::make_unique<PrintExprAST>(std::move(loc),
std::move(Args[0]));
}
// Call to a user-defined function
return llvm::make_unique<CallExprAST>(std::move(loc), name,
std::move(Args));
}
/// primary
/// ::= identifierexpr
/// ::= numberexpr
/// ::= parenexpr
/// ::= tensorliteral
std::unique_ptr<ExprAST> ParsePrimary() {
switch (lexer.getCurToken()) {
default:
llvm::errs() << "unknown token '" << lexer.getCurToken()
<< "' when expecting an expression\n";
return nullptr;
case tok_identifier:
return ParseIdentifierExpr();
case tok_number:
return ParseNumberExpr();
case '(':
return ParseParenExpr();
case '[':
return ParseTensorLiteralExpr();
case ';':
return nullptr;
case '}':
return nullptr;
}
}
/// Recursively parse the right-hand side of a binary expression. The ExprPrec
/// argument indicates the precedence of the current binary operator.
///
/// binoprhs ::= ('+' primary)*
std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
std::unique_ptr<ExprAST> LHS) {
// If this is a binop, find its precedence.
while (true) {
int TokPrec = GetTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
if (TokPrec < ExprPrec)
return LHS;
// Okay, we know this is a binop.
int BinOp = lexer.getCurToken();
lexer.consume(Token(BinOp));
auto loc = lexer.getLastLocation();
// Parse the primary expression after the binary operator.
auto RHS = ParsePrimary();
if (!RHS)
return parseError<ExprAST>("expression", "to complete binary operator");
// If BinOp binds less tightly with RHS than the operator after RHS, let
// the pending operator take RHS as its LHS.
int NextPrec = GetTokPrecedence();
if (TokPrec < NextPrec) {
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
if (!RHS)
return nullptr;
}
// Merge LHS/RHS.
LHS = llvm::make_unique<BinaryExprAST>(std::move(loc), BinOp,
std::move(LHS), std::move(RHS));
}
}
/// expression::= primary binoprhs
std::unique_ptr<ExprAST> ParseExpression() {
auto LHS = ParsePrimary();
if (!LHS)
return nullptr;
return ParseBinOpRHS(0, std::move(LHS));
}
/// type ::= < shape_list >
/// shape_list ::= num | num , shape_list
std::unique_ptr<VarType> ParseType() {
if (lexer.getCurToken() != '<')
return parseError<VarType>("<", "to begin type");
lexer.getNextToken(); // eat <
auto type = llvm::make_unique<VarType>();
while (lexer.getCurToken() == tok_number) {
type->shape.push_back(lexer.getValue());
lexer.getNextToken();
if (lexer.getCurToken() == ',')
lexer.getNextToken();
}
if (lexer.getCurToken() != '>')
return parseError<VarType>(">", "to end type");
lexer.getNextToken(); // eat >
return type;
}
/// Parse a variable declaration. It starts with a `var` keyword, followed by
/// an identifier and an optional type (shape specification), before the
/// initializer.
/// decl ::= var identifier [ type ] = expr
std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
if (lexer.getCurToken() != tok_var)
return parseError<VarDeclExprAST>("var", "to begin declaration");
auto loc = lexer.getLastLocation();
lexer.getNextToken(); // eat var
if (lexer.getCurToken() != tok_identifier)
return parseError<VarDeclExprAST>("identified",
"after 'var' declaration");
std::string id = lexer.getId();
lexer.getNextToken(); // eat id
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
if (lexer.getCurToken() == '<') {
type = ParseType();
if (!type)
return nullptr;
}
if (!type)
type = llvm::make_unique<VarType>();
lexer.consume(Token('='));
auto expr = ParseExpression();
return llvm::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
std::move(*type), std::move(expr));
}
/// Parse a block: a list of expressions separated by semicolons and wrapped in
/// curly braces.
///
/// block ::= { expression_list }
/// expression_list ::= block_expr ; expression_list
/// block_expr ::= decl | "return" | expr
std::unique_ptr<ExprASTList> ParseBlock() {
if (lexer.getCurToken() != '{')
return parseError<ExprASTList>("{", "to begin block");
lexer.consume(Token('{'));
auto exprList = llvm::make_unique<ExprASTList>();
// Ignore empty expressions: swallow sequences of semicolons.
while (lexer.getCurToken() == ';')
lexer.consume(Token(';'));
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
if (lexer.getCurToken() == tok_var) {
// Variable declaration
auto varDecl = ParseDeclaration();
if (!varDecl)
return nullptr;
exprList->push_back(std::move(varDecl));
} else if (lexer.getCurToken() == tok_return) {
// Return statement
auto ret = ParseReturn();
if (!ret)
return nullptr;
exprList->push_back(std::move(ret));
} else {
// General expression
auto expr = ParseExpression();
if (!expr)
return nullptr;
exprList->push_back(std::move(expr));
}
// Ensure that elements are separated by a semicolon.
if (lexer.getCurToken() != ';')
return parseError<ExprASTList>(";", "after expression");
// Ignore empty expressions: swallow sequences of semicolons.
while (lexer.getCurToken() == ';')
lexer.consume(Token(';'));
}
if (lexer.getCurToken() != '}')
return parseError<ExprASTList>("}", "to close block");
lexer.consume(Token('}'));
return exprList;
}
/// prototype ::= def id '(' decl_list ')'
/// decl_list ::= identifier | identifier, decl_list
std::unique_ptr<PrototypeAST> ParsePrototype() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_def);
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>("function name", "in prototype");
std::string FnName = lexer.getId();
lexer.consume(tok_identifier);
if (lexer.getCurToken() != '(')
return parseError<PrototypeAST>("(", "in prototype");
lexer.consume(Token('('));
std::vector<std::unique_ptr<VariableExprAST>> args;
if (lexer.getCurToken() != ')') {
do {
std::string name = lexer.getId();
auto loc = lexer.getLastLocation();
lexer.consume(tok_identifier);
auto decl = llvm::make_unique<VariableExprAST>(std::move(loc), name);
args.push_back(std::move(decl));
if (lexer.getCurToken() != ',')
break;
lexer.consume(Token(','));
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>(
"identifier", "after ',' in function parameter list");
} while (true);
}
if (lexer.getCurToken() != ')')
return parseError<PrototypeAST>("}", "to end function prototype");
// success.
lexer.consume(Token(')'));
return llvm::make_unique<PrototypeAST>(std::move(loc), FnName,
std::move(args));
}
/// Parse a function definition; we expect a prototype initiated with the
/// `def` keyword, followed by a block containing a list of expressions.
///
/// definition ::= prototype block
std::unique_ptr<FunctionAST> ParseDefinition() {
auto Proto = ParsePrototype();
if (!Proto)
return nullptr;
if (auto block = ParseBlock())
return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(block));
return nullptr;
}
/// Get the precedence of the pending binary operator token.
int GetTokPrecedence() {
if (!isascii(lexer.getCurToken()))
return -1;
// 1 is lowest precedence.
switch (static_cast<char>(lexer.getCurToken())) {
case '-':
return 20;
case '+':
return 20;
case '*':
return 40;
default:
return -1;
}
}
/// Helper function to signal errors while parsing. It takes an argument
/// indicating the expected token and another argument giving more context.
/// The location is retrieved from the lexer to enrich the error message.
template <typename R, typename T, typename U = const char *>
std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
auto curToken = lexer.getCurToken();
llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
<< lexer.getLastLocation().col << "): expected '" << expected
<< "' " << context << " but has Token " << curToken;
if (isprint(curToken))
llvm::errs() << " '" << (char)curToken << "'";
llvm::errs() << "\n";
return nullptr;
}
};
} // namespace toy
#endif // MLIR_TUTORIAL_TOY_PARSER_H

View file

@@ -0,0 +1,263 @@
//===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the AST dump for the Toy language.
//
//===----------------------------------------------------------------------===//
#include "toy/AST.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
using namespace toy;
namespace {
// RAII helper to manage increasing/decreasing the indentation as we traverse
// the AST
struct Indent {
Indent(int &level) : level(level) { ++level; }
~Indent() { --level; }
int &level;
};
/// Helper class that implements the AST tree traversal and prints the nodes along
/// the way. The only data member is the current indentation level.
class ASTDumper {
public:
void dump(ModuleAST *Node);
private:
void dump(VarType &type);
void dump(VarDeclExprAST *varDecl);
void dump(ExprAST *expr);
void dump(ExprASTList *exprList);
void dump(NumberExprAST *num);
void dump(LiteralExprAST *Node);
void dump(VariableExprAST *Node);
void dump(ReturnExprAST *Node);
void dump(BinaryExprAST *Node);
void dump(CallExprAST *Node);
void dump(PrintExprAST *Node);
void dump(PrototypeAST *Node);
void dump(FunctionAST *Node);
// Actually print spaces matching the current indentation level
void indent() {
for (int i = 0; i < curIndent; i++)
llvm::errs() << " ";
}
int curIndent = 0;
};
} // namespace
/// Return a formatted string for the location of any node
template <typename T> static std::string loc(T *Node) {
const auto &loc = Node->loc();
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
llvm::Twine(loc.col))
.str();
}
// Helper macro to bump the indentation level and print the leading spaces for
// the current indentation.
#define INDENT() \
Indent level_(curIndent); \
indent();
/// Dispatch a generic expression to the appropriate subclass using RTTI.
void ASTDumper::dump(ExprAST *expr) {
#define dispatch(CLASS) \
if (CLASS *node = llvm::dyn_cast<CLASS>(expr)) \
return dump(node);
dispatch(VarDeclExprAST);
dispatch(LiteralExprAST);
dispatch(NumberExprAST);
dispatch(VariableExprAST);
dispatch(ReturnExprAST);
dispatch(BinaryExprAST);
dispatch(CallExprAST);
dispatch(PrintExprAST);
// No match, fallback to a generic message
INDENT();
llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
}
/// A variable declaration prints the variable name and the type, and then
/// recurses into the initializer value.
void ASTDumper::dump(VarDeclExprAST *varDecl) {
INDENT();
llvm::errs() << "VarDecl " << varDecl->getName();
dump(varDecl->getType());
llvm::errs() << " " << loc(varDecl) << "\n";
dump(varDecl->getInitVal());
}
/// A "block", or a list of expression
void ASTDumper::dump(ExprASTList *exprList) {
INDENT();
llvm::errs() << "Block {\n";
for (auto &expr : *exprList)
dump(expr.get());
indent();
llvm::errs() << "} // Block\n";
}
/// A literal number, just print the value.
void ASTDumper::dump(NumberExprAST *num) {
INDENT();
llvm::errs() << num->getValue() << " " << loc(num) << "\n";
}
/// Helper to print a literal recursively. This handles nested arrays like:
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such arrays with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
void printLitHelper(ExprAST *lit_or_num) {
// Inside a literal expression we can have either a number or another literal
if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
llvm::errs() << num->getValue();
return;
}
auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
// Print the dimension for this literal first
llvm::errs() << "<";
{
const char *sep = "";
for (auto dim : literal->getDims()) {
llvm::errs() << sep << dim;
sep = ", ";
}
}
llvm::errs() << ">";
// Now print the content, recursing on every element of the list
llvm::errs() << "[ ";
const char *sep = "";
for (auto &elt : literal->getValues()) {
llvm::errs() << sep;
printLitHelper(elt.get());
sep = ", ";
}
llvm::errs() << "]";
}
/// Print a literal, see the recursive helper above for the implementation.
void ASTDumper::dump(LiteralExprAST *Node) {
INDENT();
llvm::errs() << "Literal: ";
printLitHelper(Node);
llvm::errs() << " " << loc(Node) << "\n";
}
/// Print a variable reference (just a name).
void ASTDumper::dump(VariableExprAST *Node) {
INDENT();
llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
}
/// A return statement prints the return keyword and its (optional) argument.
void ASTDumper::dump(ReturnExprAST *Node) {
INDENT();
llvm::errs() << "Return\n";
if (Node->getExpr().hasValue())
return dump(*Node->getExpr());
{
INDENT();
llvm::errs() << "(void)\n";
}
}
/// Print a binary operation, first the operator, then recurse into LHS and RHS.
void ASTDumper::dump(BinaryExprAST *Node) {
INDENT();
llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
dump(Node->getLHS());
dump(Node->getRHS());
}
/// Print a call expression, first the callee name and the list of args by
/// recursing into each individual argument.
void ASTDumper::dump(CallExprAST *Node) {
INDENT();
llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
for (auto &arg : Node->getArgs())
dump(arg.get());
indent();
llvm::errs() << "]\n";
}
/// Print a builtin print call, first the builtin name and then the argument.
void ASTDumper::dump(PrintExprAST *Node) {
INDENT();
llvm::errs() << "Print [ " << loc(Node) << "\n";
dump(Node->getArg());
indent();
llvm::errs() << "]\n";
}
/// Print type: only the shape is printed in between '<' and '>'
void ASTDumper::dump(VarType &type) {
llvm::errs() << "<";
const char *sep = "";
for (auto shape : type.shape) {
llvm::errs() << sep << shape;
sep = ", ";
}
llvm::errs() << ">";
}
/// Print a function prototype, first the function name, and then the list of
/// parameter names.
void ASTDumper::dump(PrototypeAST *Node) {
INDENT();
llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
indent();
llvm::errs() << "Params: [";
const char *sep = "";
for (auto &arg : Node->getArgs()) {
llvm::errs() << sep << arg->getName();
sep = ", ";
}
llvm::errs() << "]\n";
}
/// Print a function, first the prototype and then the body.
void ASTDumper::dump(FunctionAST *Node) {
INDENT();
llvm::errs() << "Function \n";
dump(Node->getProto());
dump(Node->getBody());
}
/// Print a module: loop over the functions and print them in sequence.
void ASTDumper::dump(ModuleAST *Node) {
INDENT();
llvm::errs() << "Module:\n";
for (auto &F : *Node)
dump(&F);
}
namespace toy {
// Public API
void dump(ModuleAST &module) { ASTDumper().dump(&module); }
} // namespace toy

View file

@@ -0,0 +1,75 @@
//===- toyc.cpp - The Toy Compiler ----------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the entry point for the Toy compiler.
//
//===----------------------------------------------------------------------===//
#include "toy/Parser.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
using namespace toy;
namespace cl = llvm::cl;
static cl::opt<std::string> InputFilename(cl::Positional,
cl::desc("<input toy file>"),
cl::init("-"),
cl::value_desc("filename"));
namespace {
enum Action { None, DumpAST };
}
static cl::opt<enum Action>
emitAction("emit", cl::desc("Select the kind of output desired"),
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code EC = FileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
return nullptr;
}
auto buffer = FileOrErr.get()->getBuffer();
LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
Parser parser(lexer);
return parser.ParseModule();
}
int main(int argc, char **argv) {
cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
auto moduleAST = parseInputFile(InputFilename);
if (!moduleAST)
return 1;
switch (emitAction) {
case Action::DumpAST:
dump(*moduleAST);
return 0;
default:
llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
}
return 0;
}

View file

@@ -0,0 +1,149 @@
# Chapter 1: Toy Tutorial Introduction
This tutorial runs through the implementation of a basic toy language on top of
MLIR. The goal of this tutorial is to introduce the concepts of MLIR, and
especially how *dialects* can help easily support language specific constructs
and transformations, while still offering an easy path to lower to LLVM or other
codegen infrastructure. This tutorial is based on the model of the
[LLVM Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/LangImpl01.html).
This tutorial is divided into the following chapters:
- [Chapter #1](Ch-1.md): Introduction to the Toy language, and the definition
of its AST.
- [Chapter #2](Ch-2.md): Traversing the AST to emit custom MLIR, introducing
base MLIR concepts.
- [Chapter #3](Ch-3.md): Defining and registering a dialect in MLIR, showing
how we can start attaching semantics to our custom operations in MLIR.
- [Chapter #4](Ch-4.md): High-level language-specific analysis and
transformation, showcasing shape inference, generic function specialization,
and basic optimizations.
- [Chapter #5](Ch-5.md): Lowering to lower-level dialects. We'll convert our
high-level language-specific semantics towards a generic linear-algebra
oriented dialect for optimizations. Ultimately we will emit LLVM IR for code
generation.
- [Chapter #6](Ch-6.md): A REPL?
- [Chapter #7](Ch-7.md): Custom backends? GPU using LLVM? TPU? XLA
## The Language
This tutorial will be illustrated with a toy language that we'll call “Toy”
(naming is hard...). Toy is an array-based language that allows you to define
functions, perform some math computation, and print results.
Because we want to keep things simple, the codegen will be limited to arrays of
rank <= 2, and the only datatype in Toy is a 64-bit floating point type (aka
double in C parlance). As such, all values are implicitly double precision.
Values are immutable: every operation returns a newly allocated value, and
deallocation is automatically managed. But enough with the long description;
nothing is better than walking through an example to get a better understanding:
FIXME: update/modify matrix multiplication to use @ instead of *
```Toy {.toy}
def main() {
# Define a variable `a` with shape <2, 3>, initialized with the literal value.
# The shape is inferred from the supplied literal.
var a = [[1, 2, 3], [4, 5, 6]];
# b is identical to a, the literal array is implicitly reshaped: defining new
# variables is the way to reshape arrays (element count must match).
var b<2, 3> = [1, 2, 3, 4, 5, 6];
# transpose() and print() are the only builtins; the following will transpose
# b and perform a matrix multiplication before printing the result.
print(a * transpose(b));
}
```
Type checking is statically performed through type inference; the language only
requires type declarations to specify array shapes when needed. Functions are
generic: their parameters are unranked (in other words, we know these are
arrays, but we don't know how many dimensions they have or how big those
dimensions are). They are
specialized for every newly discovered signature at call sites. Let's revisit
the previous example by adding a user-defined function:
```Toy {.toy}
# User defined generic function that operates on unknown shaped arguments
def multiply_transpose(a, b) {
return a * transpose(b);
}
def main() {
# Define a variable `a` with shape <2, 3>, initialized with the literal value.
var a = [[1, 2, 3], [4, 5, 6]];
var b<2, 3> = [1, 2, 3, 4, 5, 6];
# This call will specialize `multiply_transpose` with <2, 3> for both
# arguments and deduce a return type of <2, 2> in initialization of `c`.
var c = multiply_transpose(a, b);
# A second call to `multiply_transpose` with <2, 3> for both arguments will
# reuse the previously specialized and inferred version and return `<2, 2>`
var d = multiply_transpose(b, a);
# A new call with `<2, 2>` for both arguments will trigger another
# specialization of `multiply_transpose`.
var e = multiply_transpose(c, d);
# Finally, calling into `multiply_transpose` with incompatible shapes will
# trigger a shape inference error.
var e = multiply_transpose(transpose(a), c);
}
```
## The AST
The AST for the above code is fairly straightforward; here is a dump of it:
```
Module:
Function
Proto 'multiply_transpose' @test/ast.toy:5:1
Params: [a, b]
Block {
Return
BinOp: * @test/ast.toy:6:12
var: a @test/ast.toy:6:10
Call 'transpose' [ @test/ast.toy:6:14
var: b @test/ast.toy:6:24
]
} // Block
Function
Proto 'main' @test/ast.toy:9:1
Params: []
Block {
VarDecl a<2, 3> @test/ast.toy:11:3
Literal: <2, 3>[<3>[1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[4.000000e+00, 5.000000e+00, 6.000000e+00]] @test/ast.toy:11:17
VarDecl b<2, 3> @test/ast.toy:12:3
Literal: <6>[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @test/ast.toy:12:17
VarDecl c<> @test/ast.toy:15:3
Call 'multiply_transpose' [ @test/ast.toy:15:11
var: a @test/ast.toy:15:30
var: b @test/ast.toy:15:33
]
VarDecl d<> @test/ast.toy:18:3
Call 'multiply_transpose' [ @test/ast.toy:18:11
var: b @test/ast.toy:18:30
var: a @test/ast.toy:18:33
]
VarDecl e<> @test/ast.toy:21:3
Call 'multiply_transpose' [ @test/ast.toy:21:11
var: b @test/ast.toy:21:30
var: c @test/ast.toy:21:33
]
VarDecl e<> @test/ast.toy:24:3
Call 'multiply_transpose' [ @test/ast.toy:24:11
Call 'transpose' [ @test/ast.toy:24:30
var: a @test/ast.toy:24:40
]
var: c @test/ast.toy:24:44
]
} // Block
```
You can reproduce this result and play with the example in the
`examples/toy/Ch1/` directory; try running
`path/to/BUILD/bin/toyc-ch1 test/ast.toy -emit=ast` (the binary built by this
chapter's CMake target is named `toyc-ch1`).
The code for the lexer is fairly straightforward; it is all in a single header:
`examples/toy/Ch1/include/toy/Lexer.h`. The parser can be found in
`examples/toy/Ch1/include/toy/Parser.h`; it is a recursive descent parser. If
you are not familiar with such a lexer/parser combination, these are very
similar to the LLVM Kaleidoscope equivalents detailed in the first two chapters
of the
[Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/LangImpl02.html#the-abstract-syntax-tree-ast).
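For reference, here is a minimal sketch of how the two pieces fit together,
following the same pattern as `parseInputFile()` in `toyc.cpp` (the in-memory
Toy program and the `"<in-memory>"` filename are just illustrations; everything
else comes straight from the headers above):

```c++
#include "toy/Parser.h" // also pulls in toy/AST.h and toy/Lexer.h

int main() {
  // A Toy program held in a memory buffer; LexerBuffer lexes directly from it.
  std::string program =
      "def main() {\n"
      "  print([[1, 2], [3, 4]]);\n"
      "}\n";
  toy::LexerBuffer lexer(program.data(), program.data() + program.size(),
                         "<in-memory>");

  // The parser pulls tokens from the lexer one at a time and builds the AST.
  toy::Parser parser(lexer);
  std::unique_ptr<toy::ModuleAST> module = parser.ParseModule();
  if (!module)
    return 1;

  // Print the AST, which is what `-emit=ast` does.
  toy::dump(*module);
  return 0;
}
```

The driver in `toyc.cpp` does the same thing, except that it fills the buffer
from a file (or stdin) using `llvm::MemoryBuffer` and selects the action from
the `-emit` command line option.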
The [next chapter](Ch-2.md) will demonstrate how to convert this AST into MLIR.

View file

@@ -1,3 +1,8 @@
llvm_canonicalize_cmake_booleans(
LLVM_BUILD_EXAMPLES
)
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
@@ -20,6 +25,13 @@ set(MLIR_TEST_DEPENDS
mlir-translate
)
if(LLVM_BUILD_EXAMPLES)
list(APPEND MLIR_TEST_DEPENDS
toyc-ch1
)
endif()
add_lit_testsuite(check-mlir "Running the MLIR regression tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${MLIR_TEST_DEPENDS}

View file

@@ -0,0 +1,71 @@
# RUN: toyc-ch1 %s -emit=ast 2>&1 | FileCheck %s
# User defined generic function that operates solely on unknown shaped arguments
def multiply_transpose(a, b) {
return a * transpose(b);
}
def main() {
# Define a variable `a` with shape <2, 3>, initialized with the literal value
var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
var b<2, 3> = [1, 2, 3, 4, 5, 6];
# This call will specialize `multiply_transpose` with <2, 3> for both
# arguments and deduce a return type of <2, 2> in initialization of `c`.
var c = multiply_transpose(a, b);
# A second call to `multiply_transpose` with <2, 3> for both arguments will
# reuse the previously specialized and inferred version and return `<2, 2>`
var d = multiply_transpose(b, a);
# A new call with `<2, 2>` for both dimension will trigger another
# specialization of `multiply_transpose`.
var e = multiply_transpose(b, c);
# Finally, calling into `multiply_transpose` with incompatible shape will
# trigger a shape inference error.
var e = multiply_transpose(transpose(a), c);
}
# CHECK: Module:
# CHECK-NEXT: Function
# CHECK-NEXT: Proto 'multiply_transpose'
# CHECK-NEXT: Params: [a, b]
# CHECK-NEXT: Block {
# CHECK-NEXT: Return
# CHECK-NEXT: BinOp: *
# CHECK-NEXT: var: a
# CHECK-NEXT: Call 'transpose' [
# CHECK-NEXT: var: b
# CHECK-NEXT: ]
# CHECK-NEXT: } // Block
# CHECK-NEXT: Function
# CHECK-NEXT: Proto 'main'
# CHECK-NEXT: Params: []
# CHECK-NEXT: Block {
# CHECK-NEXT: VarDecl a<2, 3>
# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]]
# CHECK-NEXT: VarDecl b<2, 3>
# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]
# CHECK-NEXT: VarDecl c<>
# CHECK-NEXT: Call 'multiply_transpose' [
# CHECK-NEXT: var: a
# CHECK-NEXT: var: b
# CHECK-NEXT: ]
# CHECK-NEXT: VarDecl d<>
# CHECK-NEXT: Call 'multiply_transpose' [
# CHECK-NEXT: var: b
# CHECK-NEXT: var: a
# CHECK-NEXT: ]
# CHECK-NEXT: VarDecl e<>
# CHECK-NEXT: Call 'multiply_transpose' [
# CHECK-NEXT: var: b
# CHECK-NEXT: var: c
# CHECK-NEXT: ]
# CHECK-NEXT: VarDecl e<>
# CHECK-NEXT: Call 'multiply_transpose' [
# CHECK-NEXT: Call 'transpose' [
# CHECK-NEXT: var: a
# CHECK-NEXT: ]
# CHECK-NEXT: var: c
# CHECK-NEXT: ]
# CHECK-NEXT: } // Block

View file

@@ -0,0 +1,2 @@
if not config.build_examples:
config.unsupported = True

View file

@@ -21,7 +21,7 @@ config.name = 'MLIR'
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.td', '.mlir']
config.suffixes = ['.td', '.mlir', '.toy']
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
@@ -54,4 +54,10 @@ tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
tools = [
'mlir-opt', 'mlir-tblgen', 'mlir-translate',
]
# The following tools are optional
tools.extend([
ToolSubst('toyc-ch1', unresolved='ignore'),
])
llvm_config.add_tool_substitutions(tools, tool_dirs)

View file

@ -30,6 +30,7 @@ config.host_arch = "@HOST_ARCH@"
config.mlir_src_root = "@MLIR_SOURCE_DIR@"
config.mlir_obj_root = "@MLIR_BINARY_DIR@"
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
config.build_examples = @LLVM_BUILD_EXAMPLES@
# Support substitution of the tools_dir with user parameters. This is
# used when we can't determine the tool dir at configuration time.