Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/artic/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,32 @@ struct TypeDecl : public NamedDecl {
void print(Printer&) const override;
};

struct ExtTypeDecl : public NamedDecl {
Ptr<TypeParamList> type_params;
std::string type_name;
std::vector<std::variant<Ptr<Type>, Ptr<Expr>>> args_;
std::vector<const artic::Type*> args_types_;

ExtTypeDecl(
const Loc& loc,
Identifier&& id,
std::string&& type_name,
Ptr<TypeParamList>&& type_params,
std::vector<std::variant<Ptr<Type>, Ptr<Expr>>>&& args)
: NamedDecl(loc, std::move(id))
, type_name(type_name)
, type_params(std::move(type_params))
, args_(std::move(args)) {
}

const thorin::Def* emit(Emitter&) const override;
const artic::Type* infer(TypeChecker&) override;
void bind_head(NameBinder&) override;
void bind(NameBinder&) override;
void resolve_summons(Summoner&) override {};
void print(Printer&) const override;
};

/// Module definition.
struct ModDecl : public NamedDecl {
PtrVector<Decl> decls;
Expand Down
1 change: 1 addition & 0 deletions include/artic/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Parser : public Logger {
Ptr<ast::ImplicitDecl> parse_implicit_decl();
Ptr<ast::StaticDecl> parse_static_decl();
Ptr<ast::TypeDecl> parse_type_decl();
Ptr<ast::ExtTypeDecl> parse_ext_type_decl();
Ptr<ast::TypeParam> parse_type_param();
Ptr<ast::TypeParamList> parse_type_params();
Ptr<ast::ModDecl> parse_mod_decl();
Expand Down
1 change: 1 addition & 0 deletions include/artic/token.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace artic {
f(Struct, "struct") \
f(Enum, "enum") \
f(Type, "type") \
f(TypeExt, "type_ext") \
f(Static, "static") \
f(Implicit, "implicit") \
f(Summon, "summon") \
Expand Down
17 changes: 17 additions & 0 deletions include/artic/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,22 @@ struct TypeAlias : public PolyTypeFromDecl<UserType, ast::TypeDecl> {
friend class TypeTable;
};

struct ExtType : public PolyTypeFromDecl<UserType, ast::ExtTypeDecl> {
void print(Printer&) const override;

private:
ExtType(TypeTable& type_table, const ast::ExtTypeDecl& decl)
: PolyTypeFromDecl(type_table, decl)
{}

using UserType::convert;
const thorin::Type* convert(Emitter&, const Type*) const override;

std::string stringify(Emitter&) const override;

friend class TypeTable;
};

/// An application of a complex type with polymorphic parameters.
struct TypeApp : public Type {
const UserType* applied;
Expand Down Expand Up @@ -684,6 +700,7 @@ class TypeTable {
const EnumType* enum_type(const ast::EnumDecl&);
const ModType* mod_type(const ast::ModDecl&);
const TypeAlias* type_alias(const ast::TypeDecl&);
const ExtType* ext_type(const ast::ExtTypeDecl&);

/// Creates a type application for structures/enumeration types,
/// or returns the type alias expanded with the given type arguments.
Expand Down
18 changes: 18 additions & 0 deletions src/bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,13 +509,31 @@ void TypeDecl::bind_head(NameBinder& binder) {
binder.insert_symbol(*this);
}

void ExtTypeDecl::bind_head(NameBinder& binder) {
binder.insert_symbol(*this);
}

void TypeDecl::bind(NameBinder& binder) {
binder.push_scope();
if (type_params) binder.bind(*type_params);
binder.bind(*aliased_type);
binder.pop_scope();
}

void ExtTypeDecl::bind(NameBinder& binder) {
binder.push_scope();
//if (type_params) binder.bind(*type_params);
for (auto& arg : args_) {
if (auto t = std::get_if<Ptr<Type>>(&arg))
binder.bind(**t);
else if (auto e = std::get_if<Ptr<Expr>>(&arg))
binder.bind(**e);
else
assert(false);
}
binder.pop_scope();
}

void ModDecl::bind_head(NameBinder& binder) {
if (id.name != "")
binder.insert_symbol(*this);
Expand Down
31 changes: 26 additions & 5 deletions src/check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ const Type* TypeChecker::invalid_cast(const Loc& loc, const Type* type, const Ty

const Type* TypeChecker::invalid_simd(const Loc& loc, const Type* elem_type) {
if (should_report_error(elem_type))
error(loc, "expected primitive type for simd type component, but got '{}'", *elem_type);
error(loc, "expected primitive or pointer type for simd type component, but got '{}'", *elem_type);
return type_table.type_error();
}

Expand Down Expand Up @@ -489,7 +489,7 @@ const Type* TypeChecker::infer_array(
if (elem_count == 0)
return cannot_infer(loc, msg);
auto elem_type = infer_elems();
if (is_simd && !elem_type->template isa<PrimType>())
if (is_simd && !(elem_type->template isa<PrimType>() || elem_type->template isa<PtrType>()))
return invalid_simd(loc, elem_type);
return type_table.sized_array_type(elem_type, elem_count, is_simd);
}
Expand All @@ -509,7 +509,7 @@ const Type* TypeChecker::check_array(
if (is_simd_type(array_type) != is_simd)
return incompatible_type(loc, (is_simd ? "simd " : "non-simd ") + std::string(msg), expected);
auto elem_type = array_type->elem;
if (is_simd && !elem_type->isa<PrimType>())
if (is_simd && !(elem_type->template isa<PrimType>() || elem_type->template isa<PtrType>()))
return invalid_simd(loc, elem_type);
check_elems(elem_type);
if (auto sized_array_type = array_type->isa<artic::SizedArrayType>();
Expand Down Expand Up @@ -847,7 +847,7 @@ const artic::Type* TupleType::infer(TypeChecker& checker) {

const artic::Type* SizedArrayType::infer(TypeChecker& checker) {
auto elem_type = checker.infer(*elem);
if (is_simd && !elem_type->isa<artic::PrimType>())
if (is_simd && !(elem_type->template isa<artic::PrimType>() || elem_type->template isa<artic::PtrType>()))
return checker.invalid_simd(loc, elem_type);

if (std::holds_alternative<ast::Path>(size)) {
Expand Down Expand Up @@ -1022,7 +1022,7 @@ const artic::Type* ArrayExpr::check(TypeChecker& checker, const artic::Type* exp

const artic::Type* RepeatArrayExpr::infer(TypeChecker& checker) {
auto elem_type = checker.deref(elem);
if (is_simd && !elem_type->isa<artic::PrimType>())
if (is_simd && !(elem_type->template isa<artic::PrimType>() || elem_type->template isa<artic::PtrType>()))
return checker.invalid_simd(loc, elem_type);

if (std::holds_alternative<ast::Path>(size)) {
Expand Down Expand Up @@ -1816,6 +1816,27 @@ const artic::Type* TypeDecl::infer(TypeChecker& checker) {
return type;
}

const artic::Type* ExtTypeDecl::infer(TypeChecker& checker) {
if (!checker.enter_decl(this))
return checker.type_table.type_error();
const ExtType* ext_type = checker.type_table.ext_type(*this);
// Set the type before entering the args
type = ext_type;
for (auto& arg : args_) {
if (auto t = std::get_if<Ptr<Type>>(&arg))
args_types_.emplace_back(checker.infer(**t));
else if (auto e = std::get_if<Ptr<Expr>>(&arg)) {
checker.infer(**e);
if (!(*e)->is_constant())
checker.error((*e)->loc, "only constants are allowed as external type members");
args_types_.emplace_back(nullptr);
} else
assert(false);
}
checker.exit_decl(this);
return ext_type;
}

const artic::Type* ModDecl::infer(TypeChecker& checker) {
for (auto& decl : decls)
checker.infer(*decl);
Expand Down
42 changes: 40 additions & 2 deletions src/emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,10 @@ const thorin::Def* TypeDecl::emit(Emitter&) const {
return nullptr;
}

const thorin::Def* ExtTypeDecl::emit(Emitter&) const {
return nullptr;
}

const thorin::Def* ModDecl::emit(Emitter& emitter) const {
for (auto& decl : decls) {
// Do not emit polymorphic functions directly: Those will be emitted from
Expand Down Expand Up @@ -1885,8 +1889,17 @@ std::string SizedArrayType::stringify(Emitter& emitter) const {
}

const thorin::Type* SizedArrayType::convert(Emitter& emitter) const {
if (is_simd)
return emitter.world.prim_type(elem->convert(emitter)->as<thorin::PrimType>()->primtype_tag(), size);
if (is_simd) {
auto elem_type = elem->convert(emitter);
if (auto prim_type = elem_type->isa<thorin::PrimType>())
return emitter.world.prim_type(prim_type->primtype_tag(), size);
else if (auto ptr_type = elem_type->isa<thorin::PtrType>())
return emitter.world.ptr_type(ptr_type->pointee(), size, ptr_type->addr_space());

//This should be unreachable after type checking.
assert(false);
return nullptr;
}
return emitter.world.definite_array_type(elem->convert(emitter), size);
}

Expand Down Expand Up @@ -2023,6 +2036,31 @@ const thorin::Type* TypeApp::convert(Emitter& emitter) const {
return result;
}

std::string ExtType::stringify(Emitter& emitter) const {
if (!type_params())
return decl.id.name;
return stringify_params(emitter, decl.id.name + "_", type_params()->params);
}

const thorin::Type* ExtType::convert(Emitter& emitter, const Type* parent) const {
if (auto it = emitter.types.find(this); !type_params() && it != emitter.types.end())
return it->second;

auto type = emitter.world.extern_type(decl.type_name, decl.args_types_.size(), { decl.id.name });
emitter.types[parent] = type;

for (size_t i = 0; i < decl.args_types_.size(); i++) {
if (auto t = decl.args_types_[i]) {
type->set_op(i, t->convert(emitter));
} else if (auto e = std::get_if<Ptr<ast::Expr>>(&decl.args_[i]))
type->set_op(i, emitter.emit(**e));
else
assert(false);
}

return type;
}

// A read-only buffer from memory, not performing any copy.
struct MemBuf : public std::streambuf {
MemBuf(const std::string& str) {
Expand Down
1 change: 1 addition & 0 deletions src/lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ std::unordered_map<std::string, Token::Tag> Lexer::keywords{
std::make_pair("struct", Token::Struct),
std::make_pair("enum", Token::Enum),
std::make_pair("type", Token::Type),
std::make_pair("type_ext", Token::TypeExt),
std::make_pair("implicit", Token::Implicit),
std::make_pair("summon", Token::Summon),
std::make_pair("static", Token::Static),
Expand Down
25 changes: 25 additions & 0 deletions src/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Ptr<ast::Decl> Parser::parse_decl(bool is_top_level) {
case Token::Struct: decl = parse_struct_decl(); break;
case Token::Enum: decl = parse_enum_decl(); break;
case Token::Type: decl = parse_type_decl(); break;
case Token::TypeExt: decl = parse_ext_type_decl(); break;
case Token::Implicit: decl = parse_implicit_decl(); break;
case Token::Static: decl = parse_static_decl(); break;
case Token::Mod: decl = parse_mod_decl(); break;
Expand Down Expand Up @@ -200,6 +201,30 @@ Ptr<ast::TypeDecl> Parser::parse_type_decl() {
return make_ptr<ast::TypeDecl>(tracker(), std::move(id), std::move(type_params), std::move(aliased_type));
}

Ptr<ast::ExtTypeDecl> Parser::parse_ext_type_decl() {
Tracker tracker(this);
eat(Token::TypeExt);
auto id = parse_id();

Ptr<ast::TypeParamList> type_params;
// if (ahead().tag() == Token::LBracket)
// type_params = parse_type_params();

std::vector<std::variant<Ptr<ast::Type>, Ptr<ast::Expr>>> args;
expect(Token::Eq);
auto type_name = parse_str();
expect(Token::LBrace);
parse_list(Token::RBrace, Token::Comma, [&] {
if (accept(Token::Type))
args.emplace_back(parse_type());
else
args.emplace_back(parse_expr());
});

expect(Token::Semi);
return make_ptr<ast::ExtTypeDecl>(tracker(), std::move(id), std::move(type_name), std::move(type_params), std::move(args));
}

Ptr<ast::ImplicitDecl> Parser::parse_implicit_decl() {
Tracker tracker(this);
eat(Token::Implicit);
Expand Down
20 changes: 20 additions & 0 deletions src/print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,22 @@ void TypeDecl::print(Printer& p) const {
p << ';';
}

void ExtTypeDecl::print(Printer& p) const {
if (attrs) attrs->print(p);
p << log::keyword_style("type_ext") << ' ' << id.name;
// if (type_params) type_params->print(p);
p << " = \"" << type_name << "\" { ";
print_list(p, ',', args_, [&] (auto& arg) {
if (auto t = std::get_if<Ptr<ast::Type>>(&arg)) {
p << **t;
} else if (auto e = std::get_if<Ptr<ast::Expr>>(&arg))
p << **e;
else
p << "invalid";
});
p << " };";
}

void ModDecl::print(Printer& p) const {
if (attrs) attrs->print(p);
bool anon = id.name == "";
Expand Down Expand Up @@ -825,6 +841,10 @@ void TypeAlias::print(Printer& p) const {
p << decl.id.name;
}

void ExtType::print(Printer& p) const {
p << decl.id.name;
}

void TypeApp::print(Printer& p) const {
applied->print(p);
p << '[';
Expand Down
4 changes: 4 additions & 0 deletions src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,10 @@ const TypeAlias* TypeTable::type_alias(const ast::TypeDecl& decl) {
return insert<TypeAlias>(decl);
}

const ExtType* TypeTable::ext_type(const ast::ExtTypeDecl& decl) {
return insert<ExtType>(decl);
}

const Type* TypeTable::type_app(const UserType* applied, const ArrayRef<const Type*>& type_args) {
if (auto type_alias = applied->isa<TypeAlias>()) {
assert(type_alias->type_params() && type_alias->decl.aliased_type->type);
Expand Down
Loading