Просмотр исходного кода

Mark instructions that can be deduced through in `typed_insts.h` (#4588)

Co-authored-by: Josh L <josh11b@users.noreply.github.com>
josh11b 1 год назад
Родитель
Сommit
d5e022d53c
3 измененных файлов с 45 добавлено и 34 удалено
  1. 14 24
      toolchain/check/deduce.cpp
  2. 11 0
      toolchain/sem_ir/inst_kind.h
  3. 20 10
      toolchain/sem_ir/typed_insts.h

+ 14 - 24
toolchain/check/deduce.cpp

@@ -417,30 +417,6 @@ auto DeductionContext::Deduce() -> bool {
         continue;
       }
 
-      // Various kinds of parameter should match an argument of the same form,
-      // if the operands all match.
-      case SemIR::ArrayType::Kind:
-      case SemIR::ClassType::Kind:
-      case SemIR::ConstType::Kind:
-      case SemIR::FacetType::Kind:
-      case SemIR::FloatType::Kind:
-      case SemIR::IntType::Kind:
-      case SemIR::PointerType::Kind:
-      case SemIR::StructType::Kind:
-      case SemIR::TupleType::Kind:
-      case SemIR::TupleValue::Kind: {
-        auto arg_inst = context().insts().Get(arg_id);
-        if (arg_inst.kind() != param_inst.kind()) {
-          break;
-        }
-        auto [kind0, kind1] = param_inst.ArgKinds();
-        worklist_.AddInstArg(kind0, param_inst.arg0(), arg_inst.arg0(),
-                             needs_substitution);
-        worklist_.AddInstArg(kind1, param_inst.arg1(), arg_inst.arg1(),
-                             needs_substitution);
-        continue;
-      }
-
       case SemIR::StructValue::Kind:
         // TODO: Match field name order between param and arg.
         break;
@@ -448,6 +424,20 @@ auto DeductionContext::Deduce() -> bool {
         // TODO: Handle more cases.
 
       default:
+        if (param_inst.kind().deduce_through()) {
+          // Various kinds of parameter should match an argument of the same
+          // form, if the operands all match.
+          auto arg_inst = context().insts().Get(arg_id);
+          if (arg_inst.kind() != param_inst.kind()) {
+            break;
+          }
+          auto [kind0, kind1] = param_inst.ArgKinds();
+          worklist_.AddInstArg(kind0, param_inst.arg0(), arg_inst.arg0(),
+                               needs_substitution);
+          worklist_.AddInstArg(kind1, param_inst.arg1(), arg_inst.arg1(),
+                               needs_substitution);
+          continue;
+        }
         break;
     }
 

+ 11 - 0
toolchain/sem_ir/inst_kind.h

@@ -94,6 +94,7 @@ class InstKind : public CARBON_ENUM_BASE(InstKind) {
     InstConstantKind constant_kind = InstConstantKind::Never;
     TerminatorKind terminator_kind = TerminatorKind::NotTerminator;
     bool is_lowered = true;
+    bool deduce_through = false;
   };
 
   // Provides a definition for this instruction kind. Should only be called
@@ -135,6 +136,12 @@ class InstKind : public CARBON_ENUM_BASE(InstKind) {
     return definition_info(*this).terminator_kind;
   }
 
+  // Returns true if `Instruction(A)` == `Instruction(B)` allows deduction to
+  // conclude `A` == `B`.
+  auto deduce_through() const -> bool {
+    return definition_info(*this).deduce_through;
+  }
+
   // Compute a fingerprint for this instruction kind, allowing its use as part
   // of the key in a `FoldingSet`.
   void Profile(llvm::FoldingSetNodeID& id) { id.AddInteger(AsInt()); }
@@ -187,6 +194,10 @@ class InstKind::Definition : public InstKind {
   // Returns true if the instruction is lowered.
   constexpr auto is_lowered() const -> bool { return info_.is_lowered; }
 
+  // Returns true if `Instruction(A)` == `Instruction(B)` allows deduction to
+  // conclude `A` == `B`.
+  constexpr auto deduce_through() const -> bool { return info_.deduce_through; }
+
  private:
   friend class InstKind;
 

+ 20 - 10
toolchain/sem_ir/typed_insts.h

@@ -166,7 +166,8 @@ struct ArrayType {
   static constexpr auto Kind = InstKind::ArrayType.Define<Parse::ArrayExprId>(
       {.ir_name = "array_type",
        .is_type = InstIsType::Always,
-       .constant_kind = InstConstantKind::Conditional});
+       .constant_kind = InstConstantKind::Conditional,
+       .deduce_through = true});
 
   TypeId type_id;
   InstId bound_id;
@@ -501,7 +502,8 @@ struct ClassType {
   static constexpr auto Kind = InstKind::ClassType.Define<Parse::NodeId>(
       {.ir_name = "class_type",
        .is_type = InstIsType::Always,
-       .constant_kind = InstConstantKind::Always});
+       .constant_kind = InstConstantKind::Always,
+       .deduce_through = true});
 
   TypeId type_id;
   ClassId class_id;
@@ -534,7 +536,8 @@ struct ConstType {
       InstKind::ConstType.Define<Parse::PrefixOperatorConstId>(
           {.ir_name = "const_type",
            .is_type = InstIsType::Always,
-           .constant_kind = InstConstantKind::Conditional});
+           .constant_kind = InstConstantKind::Conditional,
+           .deduce_through = true});
 
   TypeId type_id;
   TypeId inner_id;
@@ -606,7 +609,8 @@ struct FacetType {
   static constexpr auto Kind = InstKind::FacetType.Define<Parse::NodeId>(
       {.ir_name = "facet_type",
        .is_type = InstIsType::Always,
-       .constant_kind = InstConstantKind::Always});
+       .constant_kind = InstConstantKind::Always,
+       .deduce_through = true});
 
   TypeId type_id;
   FacetTypeId facet_type_id;
@@ -656,7 +660,8 @@ struct FloatType {
   static constexpr auto Kind = InstKind::FloatType.Define<Parse::InvalidNodeId>(
       {.ir_name = "float_type",
        .is_type = InstIsType::Always,
-       .constant_kind = InstConstantKind::Conditional});
+       .constant_kind = InstConstantKind::Conditional,
+       .deduce_through = true});
 
   TypeId type_id;
   // TODO: Consider adding a more compact way of representing either a small
@@ -879,7 +884,8 @@ struct IntType {
   static constexpr auto Kind = InstKind::IntType.Define<Parse::InvalidNodeId>(
       {.ir_name = "int_type",
        .is_type = InstIsType::Always,
-       .constant_kind = InstConstantKind::Conditional});
+       .constant_kind = InstConstantKind::Conditional,
+       .deduce_through = true});
 
   TypeId type_id;
   IntKind int_kind;
@@ -1008,7 +1014,8 @@ struct PointerType {
       InstKind::PointerType.Define<Parse::PostfixOperatorStarId>(
           {.ir_name = "ptr_type",
            .is_type = InstIsType::Always,
-           .constant_kind = InstConstantKind::Conditional});
+           .constant_kind = InstConstantKind::Conditional,
+           .deduce_through = true});
 
   TypeId type_id;
   TypeId pointee_id;
@@ -1233,7 +1240,8 @@ struct StructType {
       InstKind::StructType.Define<Parse::StructTypeLiteralId>(
           {.ir_name = "struct_type",
            .is_type = InstIsType::Always,
-           .constant_kind = InstConstantKind::Conditional});
+           .constant_kind = InstConstantKind::Conditional,
+           .deduce_through = true});
 
   TypeId type_id;
   StructTypeFieldsId fields_id;
@@ -1307,7 +1315,8 @@ struct TupleType {
   static constexpr auto Kind = InstKind::TupleType.Define<Parse::InvalidNodeId>(
       {.ir_name = "tuple_type",
        .is_type = InstIsType::Always,
-       .constant_kind = InstConstantKind::Conditional});
+       .constant_kind = InstConstantKind::Conditional,
+       .deduce_through = true});
 
   TypeId type_id;
   TypeBlockId elements_id;
@@ -1317,7 +1326,8 @@ struct TupleType {
 struct TupleValue {
   static constexpr auto Kind = InstKind::TupleValue.Define<Parse::NodeId>(
       {.ir_name = "tuple_value",
-       .constant_kind = InstConstantKind::Conditional});
+       .constant_kind = InstConstantKind::Conditional,
+       .deduce_through = true});
 
   TypeId type_id;
   InstBlockId elements_id;