deduce.cpp 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "toolchain/check/deduce.h"
  5. #include "toolchain/base/kind_switch.h"
  6. #include "toolchain/check/context.h"
  7. #include "toolchain/check/convert.h"
  8. #include "toolchain/check/generic.h"
  9. #include "toolchain/check/subst.h"
  10. #include "toolchain/sem_ir/ids.h"
  11. #include "toolchain/sem_ir/typed_insts.h"
  12. namespace Carbon::Check {
  13. namespace {
  14. // A list of pairs of (instruction from generic, corresponding instruction from
  15. // call to of generic) for which we still need to perform deduction, along with
  16. // methods to add and pop pending deductions from the list. Deductions are
  17. // popped in order from most- to least-recently pushed, with the intent that
  18. // they are visited in depth-first order, although the order is not expected to
  19. // matter except when it influences which error is diagnosed.
  20. class DeductionWorklist {
  21. public:
  22. explicit DeductionWorklist(Context& context) : context_(context) {}
  23. struct PendingDeduction {
  24. SemIR::InstId param;
  25. SemIR::InstId arg;
  26. bool needs_substitution;
  27. };
  28. // Adds a single (param, arg) deduction.
  29. auto Add(SemIR::InstId param, SemIR::InstId arg, bool needs_substitution)
  30. -> void {
  31. deductions_.push_back(
  32. {.param = param, .arg = arg, .needs_substitution = needs_substitution});
  33. }
  34. // Adds a list of (param, arg) deductions. These are added in reverse order so
  35. // they are popped in forward order.
  36. auto AddAll(llvm::ArrayRef<SemIR::InstId> params,
  37. llvm::ArrayRef<SemIR::InstId> args, bool needs_substitution)
  38. -> void {
  39. if (params.size() != args.size()) {
  40. // TODO: Decide whether to error on this or just treat the parameter list
  41. // as non-deduced. For now we treat it as non-deduced.
  42. return;
  43. }
  44. for (auto [param, arg] : llvm::reverse(llvm::zip_equal(params, args))) {
  45. Add(param, arg, needs_substitution);
  46. }
  47. }
  48. auto AddAll(SemIR::InstBlockId params, llvm::ArrayRef<SemIR::InstId> args,
  49. bool needs_substitution) -> void {
  50. AddAll(context_.inst_blocks().Get(params), args, needs_substitution);
  51. }
  52. auto AddAll(SemIR::InstBlockId params, SemIR::InstBlockId args,
  53. bool needs_substitution) -> void {
  54. AddAll(context_.inst_blocks().Get(params), context_.inst_blocks().Get(args),
  55. needs_substitution);
  56. }
  57. // Returns whether we have completed all deductions.
  58. auto Done() -> bool { return deductions_.empty(); }
  59. // Pops the next deduction. Requires `!Done()`.
  60. auto PopNext() -> PendingDeduction { return deductions_.pop_back_val(); }
  61. private:
  62. Context& context_;
  63. llvm::SmallVector<PendingDeduction> deductions_;
  64. };
  65. } // namespace
  66. static auto NoteGenericHere(Context& context, SemIR::GenericId generic_id,
  67. Context::DiagnosticBuilder& diag) -> void {
  68. CARBON_DIAGNOSTIC(DeductionGenericHere, Note,
  69. "while deducing parameters of generic declared here");
  70. diag.Note(context.generics().Get(generic_id).decl_id, DeductionGenericHere);
  71. }
  72. auto DeduceGenericCallArguments(
  73. Context& context, SemIR::LocId loc_id, SemIR::GenericId generic_id,
  74. SemIR::SpecificId enclosing_specific_id,
  75. [[maybe_unused]] SemIR::InstBlockId implicit_params_id,
  76. SemIR::InstBlockId params_id, [[maybe_unused]] SemIR::InstId self_id,
  77. llvm::ArrayRef<SemIR::InstId> arg_ids) -> SemIR::SpecificId {
  78. DeductionWorklist worklist(context);
  79. llvm::SmallVector<SemIR::InstId> result_arg_ids;
  80. llvm::SmallVector<Substitution> substitutions;
  81. // Copy any outer generic arguments from the specified instance and prepare to
  82. // substitute them into the function declaration.
  83. if (enclosing_specific_id.is_valid()) {
  84. auto args = context.inst_blocks().Get(
  85. context.specifics().Get(enclosing_specific_id).args_id);
  86. result_arg_ids.assign(args.begin(), args.end());
  87. // TODO: Subst is linear in the length of the substitutions list. Change it
  88. // so we can pass in an array mapping indexes to substitutions instead.
  89. substitutions.reserve(args.size());
  90. for (auto [i, subst_inst_id] : llvm::enumerate(args)) {
  91. substitutions.push_back(
  92. {.bind_id = SemIR::CompileTimeBindIndex(i),
  93. .replacement_id = context.constant_values().Get(subst_inst_id)});
  94. }
  95. }
  96. auto first_deduced_index = SemIR::CompileTimeBindIndex(result_arg_ids.size());
  97. // Initialize the deduced arguments to Invalid.
  98. result_arg_ids.resize(context.inst_blocks()
  99. .Get(context.generics().Get(generic_id).bindings_id)
  100. .size(),
  101. SemIR::InstId::Invalid);
  102. // Prepare to perform deduction of the explicit parameters against their
  103. // arguments.
  104. // TODO: Also perform deduction for type of self.
  105. worklist.AddAll(params_id, arg_ids, /*needs_substitution=*/true);
  106. while (!worklist.Done()) {
  107. auto [param_id, arg_id, needs_substitution] = worklist.PopNext();
  108. // If the parameter has a symbolic type, deduce against that.
  109. auto param_type_id = context.insts().Get(param_id).type_id();
  110. if (param_type_id.AsConstantId().is_symbolic()) {
  111. worklist.Add(
  112. context.types().GetInstId(param_type_id),
  113. context.types().GetInstId(context.insts().Get(arg_id).type_id()),
  114. needs_substitution);
  115. } else {
  116. // The argument needs to have the same type as the parameter.
  117. DiagnosticAnnotationScope annotate_diagnostics(
  118. &context.emitter(), [&](auto& builder) {
  119. if (auto param = context.insts().TryGetAs<SemIR::BindSymbolicName>(
  120. param_id)) {
  121. CARBON_DIAGNOSTIC(
  122. InitializingGenericParam, Note,
  123. "initializing generic parameter `{0}` declared here",
  124. SemIR::NameId);
  125. builder.Note(
  126. param_id, InitializingGenericParam,
  127. context.entity_names().Get(param->entity_name_id).name_id);
  128. }
  129. });
  130. arg_id = ConvertToValueOfType(context, loc_id, arg_id, param_type_id);
  131. if (arg_id == SemIR::InstId::BuiltinError) {
  132. return SemIR::SpecificId::Invalid;
  133. }
  134. }
  135. // If the parameter is a symbolic constant, deduce against it.
  136. auto param_const_id = context.constant_values().Get(param_id);
  137. if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
  138. continue;
  139. }
  140. // If we've not yet substituted into the parameter, do so now.
  141. if (needs_substitution) {
  142. param_const_id = SubstConstant(context, param_const_id, substitutions);
  143. if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
  144. continue;
  145. }
  146. needs_substitution = false;
  147. }
  148. CARBON_KIND_SWITCH(context.insts().Get(context.constant_values().GetInstId(
  149. param_const_id))) {
  150. // Deducing a symbolic binding from an argument with a constant value
  151. // deduces the binding as having that constant value.
  152. case CARBON_KIND(SemIR::BindSymbolicName bind): {
  153. auto& entity_name = context.entity_names().Get(bind.entity_name_id);
  154. auto index = entity_name.bind_index;
  155. if (index.is_valid() && index >= first_deduced_index) {
  156. CARBON_CHECK(static_cast<size_t>(index.index) < result_arg_ids.size(),
  157. "Deduced value for unexpected index {0}; expected to "
  158. "deduce {1} arguments.",
  159. index, result_arg_ids.size());
  160. auto arg_const_inst_id =
  161. context.constant_values().GetConstantInstId(arg_id);
  162. if (arg_const_inst_id.is_valid()) {
  163. if (result_arg_ids[index.index].is_valid() &&
  164. result_arg_ids[index.index] != arg_const_inst_id) {
  165. // TODO: Include the two different deduced values.
  166. CARBON_DIAGNOSTIC(DeductionInconsistent, Error,
  167. "inconsistent deductions for value of generic "
  168. "parameter `{0}`",
  169. SemIR::NameId);
  170. auto diag = context.emitter().Build(loc_id, DeductionInconsistent,
  171. entity_name.name_id);
  172. NoteGenericHere(context, generic_id, diag);
  173. diag.Emit();
  174. return SemIR::SpecificId::Invalid;
  175. }
  176. result_arg_ids[index.index] = arg_const_inst_id;
  177. }
  178. }
  179. break;
  180. }
  181. // TODO: Handle more cases.
  182. default:
  183. break;
  184. }
  185. }
  186. // Check we deduced an argument value for every parameter.
  187. for (auto [i, deduced_arg_id] :
  188. llvm::enumerate(llvm::ArrayRef(result_arg_ids)
  189. .drop_front(first_deduced_index.index))) {
  190. if (!deduced_arg_id.is_valid()) {
  191. auto binding_index = first_deduced_index.index + i;
  192. auto binding_id = context.inst_blocks().Get(
  193. context.generics().Get(generic_id).bindings_id)[binding_index];
  194. auto entity_name_id =
  195. context.insts().GetAs<SemIR::AnyBindName>(binding_id).entity_name_id;
  196. CARBON_DIAGNOSTIC(DeductionIncomplete, Error,
  197. "cannot deduce value for generic parameter `{0}`",
  198. SemIR::NameId);
  199. auto diag = context.emitter().Build(
  200. loc_id, DeductionIncomplete,
  201. context.entity_names().Get(entity_name_id).name_id);
  202. NoteGenericHere(context, generic_id, diag);
  203. diag.Emit();
  204. return SemIR::SpecificId::Invalid;
  205. }
  206. }
  207. // TODO: Convert the deduced values to the types of the bindings.
  208. return MakeSpecific(context, generic_id,
  209. context.inst_blocks().AddCanonical(result_arg_ids));
  210. }
  211. } // namespace Carbon::Check