diff --git a/include/artic/ast.h b/include/artic/ast.h index 401450d8..6f175501 100644 --- a/include/artic/ast.h +++ b/include/artic/ast.h @@ -1506,6 +1506,32 @@ struct TypeDecl : public NamedDecl { void print(Printer&) const override; }; +struct ExtTypeDecl : public NamedDecl { + Ptr type_params; + std::string type_name; + std::vector, Ptr>> args_; + std::vector args_types_; + + ExtTypeDecl( + const Loc& loc, + Identifier&& id, + std::string&& type_name, + Ptr&& type_params, + std::vector, Ptr>>&& args) + : NamedDecl(loc, std::move(id)) + , type_name(type_name) + , type_params(std::move(type_params)) + , args_(std::move(args)) { + } + + const thorin::Def* emit(Emitter&) const override; + const artic::Type* infer(TypeChecker&) override; + void bind_head(NameBinder&) override; + void bind(NameBinder&) override; + void resolve_summons(Summoner&) override {}; + void print(Printer&) const override; +}; + /// Module definition. struct ModDecl : public NamedDecl { PtrVector decls; diff --git a/include/artic/parser.h b/include/artic/parser.h index 8d5c0de1..bdf4d565 100644 --- a/include/artic/parser.h +++ b/include/artic/parser.h @@ -34,6 +34,7 @@ class Parser : public Logger { Ptr parse_implicit_decl(); Ptr parse_static_decl(); Ptr parse_type_decl(); + Ptr parse_ext_type_decl(); Ptr parse_type_param(); Ptr parse_type_params(); Ptr parse_mod_decl(); diff --git a/include/artic/token.h b/include/artic/token.h index a641430b..81ff9cfa 100644 --- a/include/artic/token.h +++ b/include/artic/token.h @@ -30,6 +30,7 @@ namespace artic { f(Struct, "struct") \ f(Enum, "enum") \ f(Type, "type") \ + f(TypeExt, "type_ext") \ f(Static, "static") \ f(Implicit, "implicit") \ f(Summon, "summon") \ diff --git a/include/artic/types.h b/include/artic/types.h index 3a6955ad..c5e82e92 100644 --- a/include/artic/types.h +++ b/include/artic/types.h @@ -582,6 +582,22 @@ struct TypeAlias : public PolyTypeFromDecl { friend class TypeTable; }; +struct ExtType : public PolyTypeFromDecl { + void print(Printer&) const override; + +private: + ExtType(TypeTable& type_table, const ast::ExtTypeDecl& decl) + : PolyTypeFromDecl(type_table, decl) + {} + + using UserType::convert; + const thorin::Type* convert(Emitter&, const Type*) const override; + + std::string stringify(Emitter&) const override; + + friend class TypeTable; +}; + /// An application of a complex type with polymorphic parameters. struct TypeApp : public Type { const UserType* applied; @@ -684,6 +700,7 @@ class TypeTable { const EnumType* enum_type(const ast::EnumDecl&); const ModType* mod_type(const ast::ModDecl&); const TypeAlias* type_alias(const ast::TypeDecl&); + const ExtType* ext_type(const ast::ExtTypeDecl&); /// Creates a type application for structures/enumeration types, /// or returns the type alias expanded with the given type arguments. diff --git a/src/bind.cpp b/src/bind.cpp index 37c4a3f4..3bdc71f2 100644 --- a/src/bind.cpp +++ b/src/bind.cpp @@ -509,6 +509,10 @@ void TypeDecl::bind_head(NameBinder& binder) { binder.insert_symbol(*this); } +void ExtTypeDecl::bind_head(NameBinder& binder) { + binder.insert_symbol(*this); +} + void TypeDecl::bind(NameBinder& binder) { binder.push_scope(); if (type_params) binder.bind(*type_params); @@ -516,6 +520,20 @@ void TypeDecl::bind(NameBinder& binder) { binder.pop_scope(); } +void ExtTypeDecl::bind(NameBinder& binder) { + binder.push_scope(); + //if (type_params) binder.bind(*type_params); + for (auto& arg : args_) { + if (auto t = std::get_if>(&arg)) + binder.bind(**t); + else if (auto e = std::get_if>(&arg)) + binder.bind(**e); + else + assert(false); + } + binder.pop_scope(); +} + void ModDecl::bind_head(NameBinder& binder) { if (id.name != "") binder.insert_symbol(*this); diff --git a/src/check.cpp b/src/check.cpp index fa6f01f9..a0315a2d 100644 --- a/src/check.cpp +++ b/src/check.cpp @@ -83,7 +83,7 @@ const Type* TypeChecker::invalid_cast(const Loc& loc, const Type* type, const Ty const Type* TypeChecker::invalid_simd(const Loc& loc, const Type* elem_type) { if (should_report_error(elem_type)) - error(loc, "expected primitive type for simd type component, but got '{}'", *elem_type); + error(loc, "expected primitive or pointer type for simd type component, but got '{}'", *elem_type); return type_table.type_error(); } @@ -489,7 +489,7 @@ const Type* TypeChecker::infer_array( if (elem_count == 0) return cannot_infer(loc, msg); auto elem_type = infer_elems(); - if (is_simd && !elem_type->template isa()) + if (is_simd && !(elem_type->template isa() || elem_type->template isa())) return invalid_simd(loc, elem_type); return type_table.sized_array_type(elem_type, elem_count, is_simd); } @@ -509,7 +509,7 @@ const Type* TypeChecker::check_array( if (is_simd_type(array_type) != is_simd) return incompatible_type(loc, (is_simd ? "simd " : "non-simd ") + std::string(msg), expected); auto elem_type = array_type->elem; - if (is_simd && !elem_type->isa()) + if (is_simd && !(elem_type->template isa() || elem_type->template isa())) return invalid_simd(loc, elem_type); check_elems(elem_type); if (auto sized_array_type = array_type->isa(); @@ -847,7 +847,7 @@ const artic::Type* TupleType::infer(TypeChecker& checker) { const artic::Type* SizedArrayType::infer(TypeChecker& checker) { auto elem_type = checker.infer(*elem); - if (is_simd && !elem_type->isa()) + if (is_simd && !(elem_type->template isa() || elem_type->template isa())) return checker.invalid_simd(loc, elem_type); if (std::holds_alternative(size)) { @@ -1022,7 +1022,7 @@ const artic::Type* ArrayExpr::check(TypeChecker& checker, const artic::Type* exp const artic::Type* RepeatArrayExpr::infer(TypeChecker& checker) { auto elem_type = checker.deref(elem); - if (is_simd && !elem_type->isa()) + if (is_simd && !(elem_type->template isa() || elem_type->template isa())) return checker.invalid_simd(loc, elem_type); if (std::holds_alternative(size)) { @@ -1816,6 +1816,27 @@ const artic::Type* TypeDecl::infer(TypeChecker& checker) { return type; } +const artic::Type* ExtTypeDecl::infer(TypeChecker& checker) { + if (!checker.enter_decl(this)) + return checker.type_table.type_error(); + const ExtType* ext_type = checker.type_table.ext_type(*this); + // Set the type before entering the args + type = ext_type; + for (auto& arg : args_) { + if (auto t = std::get_if>(&arg)) + args_types_.emplace_back(checker.infer(**t)); + else if (auto e = std::get_if>(&arg)) { + checker.infer(**e); + if (!(*e)->is_constant()) + checker.error((*e)->loc, "only constants are allowed as external type members"); + args_types_.emplace_back(nullptr); + } else + assert(false); + } + checker.exit_decl(this); + return ext_type; +} + const artic::Type* ModDecl::infer(TypeChecker& checker) { for (auto& decl : decls) checker.infer(*decl); diff --git a/src/emit.cpp b/src/emit.cpp index 83e1f7da..ab53ba5a 100644 --- a/src/emit.cpp +++ b/src/emit.cpp @@ -1756,6 +1756,10 @@ const thorin::Def* TypeDecl::emit(Emitter&) const { return nullptr; } +const thorin::Def* ExtTypeDecl::emit(Emitter&) const { + return nullptr; +} + const thorin::Def* ModDecl::emit(Emitter& emitter) const { for (auto& decl : decls) { // Do not emit polymorphic functions directly: Those will be emitted from @@ -1885,8 +1889,17 @@ std::string SizedArrayType::stringify(Emitter& emitter) const { } const thorin::Type* SizedArrayType::convert(Emitter& emitter) const { - if (is_simd) - return emitter.world.prim_type(elem->convert(emitter)->as()->primtype_tag(), size); + if (is_simd) { + auto elem_type = elem->convert(emitter); + if (auto prim_type = elem_type->isa()) + return emitter.world.prim_type(prim_type->primtype_tag(), size); + else if (auto ptr_type = elem_type->isa()) + return emitter.world.ptr_type(ptr_type->pointee(), size, ptr_type->addr_space()); + + //This should be unreachable after type checking. + assert(false); + return nullptr; + } return emitter.world.definite_array_type(elem->convert(emitter), size); } @@ -2023,6 +2036,31 @@ const thorin::Type* TypeApp::convert(Emitter& emitter) const { return result; } +std::string ExtType::stringify(Emitter& emitter) const { + if (!type_params()) + return decl.id.name; + return stringify_params(emitter, decl.id.name + "_", type_params()->params); +} + +const thorin::Type* ExtType::convert(Emitter& emitter, const Type* parent) const { + if (auto it = emitter.types.find(this); !type_params() && it != emitter.types.end()) + return it->second; + + auto type = emitter.world.extern_type(decl.type_name, decl.args_types_.size(), { decl.id.name }); + emitter.types[parent] = type; + + for (size_t i = 0; i < decl.args_types_.size(); i++) { + if (auto t = decl.args_types_[i]) { + type->set_op(i, t->convert(emitter)); + } else if (auto e = std::get_if>(&decl.args_[i])) + type->set_op(i, emitter.emit(**e)); + else + assert(false); + } + + return type; +} + // A read-only buffer from memory, not performing any copy. struct MemBuf : public std::streambuf { MemBuf(const std::string& str) { diff --git a/src/lexer.cpp b/src/lexer.cpp index 1f474765..c4f35ca3 100644 --- a/src/lexer.cpp +++ b/src/lexer.cpp @@ -23,6 +23,7 @@ std::unordered_map Lexer::keywords{ std::make_pair("struct", Token::Struct), std::make_pair("enum", Token::Enum), std::make_pair("type", Token::Type), + std::make_pair("type_ext", Token::TypeExt), std::make_pair("implicit", Token::Implicit), std::make_pair("summon", Token::Summon), std::make_pair("static", Token::Static), diff --git a/src/parser.cpp b/src/parser.cpp index 7ea25107..1466c6d6 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -42,6 +42,7 @@ Ptr Parser::parse_decl(bool is_top_level) { case Token::Struct: decl = parse_struct_decl(); break; case Token::Enum: decl = parse_enum_decl(); break; case Token::Type: decl = parse_type_decl(); break; + case Token::TypeExt: decl = parse_ext_type_decl(); break; case Token::Implicit: decl = parse_implicit_decl(); break; case Token::Static: decl = parse_static_decl(); break; case Token::Mod: decl = parse_mod_decl(); break; @@ -200,6 +201,30 @@ Ptr Parser::parse_type_decl() { return make_ptr(tracker(), std::move(id), std::move(type_params), std::move(aliased_type)); } +Ptr Parser::parse_ext_type_decl() { + Tracker tracker(this); + eat(Token::TypeExt); + auto id = parse_id(); + + Ptr type_params; + // if (ahead().tag() == Token::LBracket) + // type_params = parse_type_params(); + + std::vector, Ptr>> args; + expect(Token::Eq); + auto type_name = parse_str(); + expect(Token::LBrace); + parse_list(Token::RBrace, Token::Comma, [&] { + if (accept(Token::Type)) + args.emplace_back(parse_type()); + else + args.emplace_back(parse_expr()); + }); + + expect(Token::Semi); + return make_ptr(tracker(), std::move(id), std::move(type_name), std::move(type_params), std::move(args)); +} + Ptr Parser::parse_implicit_decl() { Tracker tracker(this); eat(Token::Implicit); diff --git a/src/print.cpp b/src/print.cpp index d28cc940..49111c65 100644 --- a/src/print.cpp +++ b/src/print.cpp @@ -614,6 +614,22 @@ void TypeDecl::print(Printer& p) const { p << ';'; } +void ExtTypeDecl::print(Printer& p) const { + if (attrs) attrs->print(p); + p << log::keyword_style("type_ext") << ' ' << id.name; + // if (type_params) type_params->print(p); + p << " = \"" << type_name << "\" { "; + print_list(p, ',', args_, [&] (auto& arg) { + if (auto t = std::get_if>(&arg)) { + p << **t; + } else if (auto e = std::get_if>(&arg)) + p << **e; + else + p << "invalid"; + }); + p << " };"; +} + void ModDecl::print(Printer& p) const { if (attrs) attrs->print(p); bool anon = id.name == ""; @@ -825,6 +841,10 @@ void TypeAlias::print(Printer& p) const { p << decl.id.name; } +void ExtType::print(Printer& p) const { + p << decl.id.name; +} + void TypeApp::print(Printer& p) const { applied->print(p); p << '['; diff --git a/src/types.cpp b/src/types.cpp index f7e3bf42..7055c8d5 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -726,6 +726,10 @@ const TypeAlias* TypeTable::type_alias(const ast::TypeDecl& decl) { return insert(decl); } +const ExtType* TypeTable::ext_type(const ast::ExtTypeDecl& decl) { + return insert(decl); +} + const Type* TypeTable::type_app(const UserType* applied, const ArrayRef& type_args) { if (auto type_alias = applied->isa()) { assert(type_alias->type_params() && type_alias->decl.aliased_type->type);