//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "FoldInitTypeCheck.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "llvm/ADT/APFloat.h"

using namespace clang::ast_matchers;

namespace clang::tidy::bugprone {

/// Registers the call matchers for std::accumulate, std::reduce and
/// std::inner_product (each with and without an execution policy). Every
/// matcher binds the call as "Call", the builtin type of the init parameter as
/// "InitType", the builtin value type of the (first) iterator parameter as
/// "IterValueType" and, for std::inner_product only, the builtin value type of
/// the second iterator parameter as "Iter2ValueType".
void FoldInitTypeCheck::registerMatchers(MatchFinder *Finder) {
  // We match functions of interest and bind the iterator and init value types.
  // Note: Right now we check only builtin types.
  // Matches a type whose canonical type is a builtin type and binds that
  // builtin type under ID.
  const auto BuiltinTypeWithId = [](const char *ID) {
    return hasCanonicalType(builtinType().bind(ID));
  };
  // Matches an iterator-like type (raw pointer or class with an `operator*`)
  // whose value type is builtin, binding that value type under ID.
  const auto IteratorWithValueType = [&BuiltinTypeWithId](const char *ID) {
    return anyOf(
        // Pointer types.
        pointsTo(BuiltinTypeWithId(ID)),
        // Iterator types have an `operator*` whose return type is the type we
        // care about.
        // Notes:
        //   - `operator*` can be in one of the bases of the iterator class.
        //   - this does not handle cases when the `operator*` is defined
        //     outside the iterator class.
        recordType(
            hasDeclaration(cxxRecordDecl(isSameOrDerivedFrom(has(functionDecl(
                hasOverloadedOperatorName("*"),
                returns(qualType(hasCanonicalType(anyOf(
                    // `value_type& operator*();`
                    references(BuiltinTypeWithId(ID)),
                    // `value_type operator*();`
                    BuiltinTypeWithId(ID),
                    // `auto operator*();`, `decltype(auto) operator*();`
                    autoType(hasDeducedType(BuiltinTypeWithId(ID)))
                    //
                    )))))))))));
  };

  // Parameter matchers: the first and second iterator parameters (binding
  // their value types under distinct IDs) and the init parameter, which must
  // itself be of builtin type.
  const auto IteratorParam = parmVarDecl(
      hasType(hasCanonicalType(IteratorWithValueType("IterValueType"))));
  const auto Iterator2Param = parmVarDecl(
      hasType(hasCanonicalType(IteratorWithValueType("Iter2ValueType"))));
  const auto InitParam = parmVarDecl(hasType(BuiltinTypeWithId("InitType")));

  // Transparent standard functors that preserve arithmetic conversion
  // semantics. Calls that pass one of these as the binary operation are still
  // checked; calls with any other operation are not matched at all.
  // NOTE(review): only the record's qualified name is checked, so
  // non-transparent specializations such as `std::plus<int>` appear to be
  // accepted as well as `std::plus<>` — confirm this is intended.
  const auto TransparentFunctor = expr(hasType(
      hasCanonicalType(recordType(hasDeclaration(cxxRecordDecl(hasAnyName(
          "::std::plus", "::std::minus", "::std::multiplies", "::std::divides",
          "::std::bit_and", "::std::bit_or", "::std::bit_xor")))))));

  // std::accumulate, std::reduce.
  // Parameter 0 is the first iterator and parameter 2 the init value. Either
  // no binary operation is passed (3 arguments) or the 4th argument is one of
  // the functors above.
  Finder->addMatcher(
      callExpr(
          callee(functionDecl(hasAnyName("::std::accumulate", "::std::reduce"),
                              hasParameter(0, IteratorParam),
                              hasParameter(2, InitParam))),
          anyOf(argumentCountIs(3),
                allOf(argumentCountIs(4), hasArgument(3, TransparentFunctor))))
          .bind("Call"),
      this);
  // std::inner_product.
  // Parameter 0 iterates the first range, parameter 2 the second range and
  // parameter 3 is the init value. With 6 arguments, both binary operations
  // (arguments 4 and 5) must be one of the functors above.
  Finder->addMatcher(
      callExpr(
          callee(functionDecl(
              hasName("::std::inner_product"), hasParameter(0, IteratorParam),
              hasParameter(2, Iterator2Param), hasParameter(3, InitParam))),
          anyOf(argumentCountIs(4),
                allOf(argumentCountIs(6), hasArgument(4, TransparentFunctor),
                      hasArgument(5, TransparentFunctor))))
          .bind("Call"),
      this);
  // std::reduce with a policy.
  // Same shape as the non-policy overload, with every parameter/argument
  // index shifted by one.
  Finder->addMatcher(
      callExpr(
          callee(functionDecl(hasName("::std::reduce"),
                              hasParameter(1, IteratorParam),
                              hasParameter(3, InitParam))),
          anyOf(argumentCountIs(4),
                allOf(argumentCountIs(5), hasArgument(4, TransparentFunctor))))
          .bind("Call"),
      this);
  // std::inner_product with a policy.
  // Same shape as the non-policy overload, with every parameter/argument
  // index shifted by one.
  Finder->addMatcher(
      callExpr(
          callee(functionDecl(
              hasName("::std::inner_product"), hasParameter(1, IteratorParam),
              hasParameter(3, Iterator2Param), hasParameter(4, InitParam))),
          anyOf(argumentCountIs(5),
                allOf(argumentCountIs(7), hasArgument(5, TransparentFunctor),
                      hasArgument(6, TransparentFunctor))))
          .bind("Call"),
      this);
}

/// Returns true if ValueType is allowed to fold into InitType, i.e. if:
///   static_cast<InitType>(ValueType{some_value})
/// does not result in trucation.
static bool isValidBuiltinFold(const BuiltinType &ValueType,
                               const BuiltinType &InitType,
                               const ASTContext &Context) {
  const auto ValueTypeSize = Context.getTypeSize(&ValueType);
  const auto InitTypeSize = Context.getTypeSize(&InitType);
  // It's OK to fold a float into a float of bigger or equal size, but not OK to
  // fold into an int.
  if (ValueType.isFloatingPoint())
    return InitType.isFloatingPoint() && InitTypeSize >= ValueTypeSize;
  // It's OK to fold an int into:
  //  - an int of the same size and signedness.
  //  - a bigger int, regardless of signedness.
  //  - FIXME: should it be a warning to fold into floating point?
  if (ValueType.isInteger()) {
    if (InitType.isInteger()) {
      if (InitType.isSignedInteger() == ValueType.isSignedInteger())
        return InitTypeSize >= ValueTypeSize;
      return InitTypeSize > ValueTypeSize;
    }
    if (InitType.isFloatingPoint())
      return InitTypeSize >= ValueTypeSize;
  }
  return false;
}

/// Reports \p CallNode when values of type \p IterValueType cannot be folded
/// into \p InitType without truncation (see isValidBuiltinFold for the exact
/// rules). Nothing is emitted when the fold is valid.
void FoldInitTypeCheck::doCheck(const BuiltinType &IterValueType,
                                const BuiltinType &InitType,
                                const ASTContext &Context,
                                const CallExpr &CallNode) {
  // Valid folds are accepted silently.
  if (isValidBuiltinFold(IterValueType, InitType, Context))
    return;

  // Convert each BuiltinType to a QualType for the %0 / %1 arguments.
  const auto FromType = IterValueType.desugar();
  const auto ToType = InitType.desugar();
  diag(CallNode.getExprLoc(), "folding type %0 into type %1 might result in "
                              "loss of precision")
      << FromType << ToType;
}

void FoldInitTypeCheck::check(const MatchFinder::MatchResult &Result) {
  // The matchers bind the init value type, the value type of the (first)
  // iterator and the call; each iterator value type is checked against the
  // init type.
  const auto &Nodes = Result.Nodes;
  const auto *InitType = Nodes.getNodeAs<BuiltinType>("InitType");
  const auto *IterValueType = Nodes.getNodeAs<BuiltinType>("IterValueType");
  const auto *CallNode = Nodes.getNodeAs<CallExpr>("Call");
  assert(InitType != nullptr);
  assert(IterValueType != nullptr);
  assert(CallNode != nullptr);

  const ASTContext &Context = *Result.Context;
  doCheck(*IterValueType, *InitType, Context, *CallNode);

  // Only the std::inner_product matchers bind a second iterator value type.
  const auto *Iter2ValueType = Nodes.getNodeAs<BuiltinType>("Iter2ValueType");
  if (Iter2ValueType != nullptr)
    doCheck(*Iter2ValueType, *InitType, Context, *CallNode);
}

} // namespace clang::tidy::bugprone
