deduce.cpp 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "toolchain/check/deduce.h"
  5. #include "toolchain/base/kind_switch.h"
  6. #include "toolchain/check/context.h"
  7. #include "toolchain/check/generic.h"
  8. #include "toolchain/check/subst.h"
  9. #include "toolchain/sem_ir/typed_insts.h"
  10. namespace Carbon::Check {
  11. namespace {
  12. // A list of pairs of (instruction from generic, corresponding instruction from
  13. // call to of generic) for which we still need to perform deduction, along with
  14. // methods to add and pop pending deductions from the list. Deductions are
  15. // popped in order from most- to least-recently pushed, with the intent that
  16. // they are visited in depth-first order, although the order is not expected to
  17. // matter except when it influences which error is diagnosed.
  18. class DeductionWorklist {
  19. public:
  20. explicit DeductionWorklist(Context& context) : context_(context) {}
  21. struct PendingDeduction {
  22. SemIR::InstId param;
  23. SemIR::InstId arg;
  24. bool needs_substitution;
  25. };
  26. // Adds a single (param, arg) deduction.
  27. auto Add(SemIR::InstId param, SemIR::InstId arg, bool needs_substitution)
  28. -> void {
  29. deductions_.push_back(
  30. {.param = param, .arg = arg, .needs_substitution = needs_substitution});
  31. }
  32. // Adds a list of (param, arg) deductions. These are added in reverse order so
  33. // they are popped in forward order.
  34. auto AddAll(llvm::ArrayRef<SemIR::InstId> params,
  35. llvm::ArrayRef<SemIR::InstId> args, bool needs_substitution)
  36. -> void {
  37. if (params.size() != args.size()) {
  38. // TODO: Decide whether to error on this or just treat the parameter list
  39. // as non-deduced. For now we treat it as non-deduced.
  40. return;
  41. }
  42. for (auto [param, arg] : llvm::reverse(llvm::zip_equal(params, args))) {
  43. Add(param, arg, needs_substitution);
  44. }
  45. }
  46. auto AddAll(SemIR::InstBlockId params, llvm::ArrayRef<SemIR::InstId> args,
  47. bool needs_substitution) -> void {
  48. AddAll(context_.inst_blocks().Get(params), args, needs_substitution);
  49. }
  50. auto AddAll(SemIR::InstBlockId params, SemIR::InstBlockId args,
  51. bool needs_substitution) -> void {
  52. AddAll(context_.inst_blocks().Get(params), context_.inst_blocks().Get(args),
  53. needs_substitution);
  54. }
  55. // Returns whether we have completed all deductions.
  56. auto Done() -> bool { return deductions_.empty(); }
  57. // Pops the next deduction. Requires `!Done()`.
  58. auto PopNext() -> PendingDeduction { return deductions_.pop_back_val(); }
  59. private:
  60. Context& context_;
  61. llvm::SmallVector<PendingDeduction> deductions_;
  62. };
  63. } // namespace
  64. static auto NoteGenericHere(Context& context, SemIR::GenericId generic_id,
  65. Context::DiagnosticBuilder& diag) -> void {
  66. CARBON_DIAGNOSTIC(DeductionGenericHere, Note,
  67. "While deducing parameters of generic declared here.");
  68. diag.Note(context.generics().Get(generic_id).decl_id, DeductionGenericHere);
  69. }
  70. auto DeduceGenericCallArguments(
  71. Context& context, SemIR::LocId loc_id, SemIR::GenericId generic_id,
  72. SemIR::SpecificId enclosing_specific_id,
  73. [[maybe_unused]] SemIR::InstBlockId implicit_params_id,
  74. SemIR::InstBlockId params_id, [[maybe_unused]] SemIR::InstId self_id,
  75. llvm::ArrayRef<SemIR::InstId> arg_ids) -> SemIR::SpecificId {
  76. DeductionWorklist worklist(context);
  77. llvm::SmallVector<SemIR::InstId> result_arg_ids;
  78. llvm::SmallVector<Substitution> substitutions;
  79. // Copy any outer generic arguments from the specified instance and prepare to
  80. // substitute them into the function declaration.
  81. if (enclosing_specific_id.is_valid()) {
  82. auto args = context.inst_blocks().Get(
  83. context.specifics().Get(enclosing_specific_id).args_id);
  84. result_arg_ids.assign(args.begin(), args.end());
  85. // TODO: Subst is linear in the length of the substitutions list. Change it
  86. // so we can pass in an array mapping indexes to substitutions instead.
  87. substitutions.reserve(args.size());
  88. for (auto [i, subst_inst_id] : llvm::enumerate(args)) {
  89. substitutions.push_back(
  90. {.bind_id = SemIR::CompileTimeBindIndex(i),
  91. .replacement_id = context.constant_values().Get(subst_inst_id)});
  92. }
  93. }
  94. auto first_deduced_index = SemIR::CompileTimeBindIndex(result_arg_ids.size());
  95. // Initialize the deduced arguments to Invalid.
  96. result_arg_ids.resize(context.inst_blocks()
  97. .Get(context.generics().Get(generic_id).bindings_id)
  98. .size(),
  99. SemIR::InstId::Invalid);
  100. // Prepare to perform deduction of the explicit parameters against their
  101. // arguments.
  102. // TODO: Also perform deduction for type of self.
  103. worklist.AddAll(params_id, arg_ids, /*needs_substitution=*/true);
  104. while (!worklist.Done()) {
  105. auto [param_id, arg_id, needs_substitution] = worklist.PopNext();
  106. // If the parameter has a symbolic type, deduce against that.
  107. auto param_type_id = context.insts().Get(param_id).type_id();
  108. if (param_type_id.AsConstantId().is_symbolic()) {
  109. worklist.Add(
  110. context.types().GetInstId(param_type_id),
  111. context.types().GetInstId(context.insts().Get(arg_id).type_id()),
  112. needs_substitution);
  113. }
  114. // If the parameter is a symbolic constant, deduce against it.
  115. auto param_const_id = context.constant_values().Get(param_id);
  116. if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
  117. continue;
  118. }
  119. // If we've not yet substituted into the parameter, do so now.
  120. if (needs_substitution) {
  121. param_const_id = SubstConstant(context, param_const_id, substitutions);
  122. if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
  123. continue;
  124. }
  125. needs_substitution = false;
  126. }
  127. CARBON_KIND_SWITCH(context.insts().Get(context.constant_values().GetInstId(
  128. param_const_id))) {
  129. // Deducing a symbolic binding from an argument with a constant value
  130. // deduces the binding as having that constant value.
  131. case CARBON_KIND(SemIR::BindSymbolicName bind): {
  132. auto& entity_name = context.entity_names().Get(bind.entity_name_id);
  133. auto index = entity_name.bind_index;
  134. if (index.is_valid() && index >= first_deduced_index) {
  135. CARBON_CHECK(static_cast<size_t>(index.index) < result_arg_ids.size(),
  136. "Deduced value for unexpected index {0}; expected to "
  137. "deduce {1} arguments.",
  138. index, result_arg_ids.size());
  139. auto arg_const_inst_id =
  140. context.constant_values().GetConstantInstId(arg_id);
  141. if (arg_const_inst_id.is_valid()) {
  142. if (result_arg_ids[index.index].is_valid() &&
  143. result_arg_ids[index.index] != arg_const_inst_id) {
  144. // TODO: Include the two different deduced values.
  145. CARBON_DIAGNOSTIC(DeductionInconsistent, Error,
  146. "Inconsistent deductions for value of generic "
  147. "parameter `{0}`.",
  148. SemIR::NameId);
  149. auto diag = context.emitter().Build(loc_id, DeductionInconsistent,
  150. entity_name.name_id);
  151. NoteGenericHere(context, generic_id, diag);
  152. diag.Emit();
  153. return SemIR::SpecificId::Invalid;
  154. }
  155. result_arg_ids[index.index] = arg_const_inst_id;
  156. }
  157. }
  158. break;
  159. }
  160. // TODO: Handle more cases.
  161. default:
  162. break;
  163. }
  164. }
  165. // Check we deduced an argument value for every parameter.
  166. for (auto [i, deduced_arg_id] :
  167. llvm::enumerate(llvm::ArrayRef(result_arg_ids)
  168. .drop_front(first_deduced_index.index))) {
  169. if (!deduced_arg_id.is_valid()) {
  170. auto binding_index = first_deduced_index.index + i;
  171. auto binding_id = context.inst_blocks().Get(
  172. context.generics().Get(generic_id).bindings_id)[binding_index];
  173. auto entity_name_id =
  174. context.insts().GetAs<SemIR::AnyBindName>(binding_id).entity_name_id;
  175. CARBON_DIAGNOSTIC(DeductionIncomplete, Error,
  176. "Cannot deduce value for generic parameter `{0}`.",
  177. SemIR::NameId);
  178. auto diag = context.emitter().Build(
  179. loc_id, DeductionIncomplete,
  180. context.entity_names().Get(entity_name_id).name_id);
  181. NoteGenericHere(context, generic_id, diag);
  182. diag.Emit();
  183. return SemIR::SpecificId::Invalid;
  184. }
  185. }
  186. // TODO: Convert the deduced values to the types of the bindings.
  187. return MakeSpecific(context, generic_id,
  188. context.inst_blocks().AddCanonical(result_arg_ids));
  189. }
  190. } // namespace Carbon::Check