//===- AddToOpToSplit.cpp - Split AddToOp into per-output generics --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
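// This file implements a pass that converts each enzyme::GenericAdjointOp into
// a linalg.generic and, when that generic contains an enzyme::AddToOp, splits
// it into one linalg.generic per output, each accumulating into its own output.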
//===----------------------------------------------------------------------===//

#include "Dialect/Dialect.h"
#include "Dialect/Ops.h"
#include "PassDetails.h"
#include "Passes/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"

#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/raw_ostream.h"

#include "Interfaces/AutoDiffTypeInterface.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "Utils.h"

using namespace mlir;
using namespace enzyme;
using llvm::errs;
namespace {

/// Searches the ops nested under `adjoint` (including `adjoint` itself) for an
/// enzyme::AddToOp.
///
/// \returns the matching op visited last by the (post-order) walk, or nullptr
///          if the generic contains no enzyme::AddToOp.
Operation *getAddToOp(linalg::GenericOp &adjoint) {
  Operation *found = nullptr;
  adjoint.walk(
      [&found](enzyme::AddToOp candidate) { found = candidate.getOperation(); });
  return found;
}

/// Returns true when `type` is a MemRefType whose element type is an enzyme
/// CacheType; returns false for every other type.
bool isMemrefCacheType(Type type) {
  auto asMemref = dyn_cast<MemRefType>(type);
  return asMemref && isa<CacheType>(asMemref.getElementType());
}

/// Clones the linalg.generic `op` at the builder's insertion point and
/// specializes the clone to the `i`-th operand of its enzyme::AddToOp:
///   * every output except the `i`-th is dropped, together with its region
///     block argument and its indexing map;
///   * the AddToOp is replaced by a linalg.yield of
///     `createAddOp(outputArg, addToOperand[i])`, where `outputArg` is the
///     block argument of the single remaining output.
/// `op` itself is not modified; the caller is responsible for erasing it.
///
/// \param op      linalg.generic whose body contains an enzyme::AddToOp.
/// \param builder builder positioned where the clone should be inserted; its
///                insertion point is restored before returning.
/// \param loc     location assigned to the newly created ops.
/// \param i       index of the output (and AddToOp operand) to keep; must be
///                non-negative and smaller than the number of outputs.
void processGenericDuplication(Operation *op, OpBuilder &builder, Location loc,
                               int i) {
  auto clonedAdjointGenericOp = cast<linalg::GenericOp>(builder.clone(*op));

  const unsigned numInputs = clonedAdjointGenericOp.getInputsMutable().size();
  const unsigned numOutputs =
      clonedAdjointGenericOp.getOutputsMutable().size();
  const unsigned keep = static_cast<unsigned>(i);
  assert(i >= 0 && keep < numOutputs && "output index out of range");

  Block &body = clonedAdjointGenericOp.getRegion().front();
  SmallVector<mlir::Attribute> indexingMaps(
      clonedAdjointGenericOp.getIndexingMaps().getValue());

  // Drop the outputs after `keep` first so that the positions of the outputs
  // before it are still valid for the second erase below.
  if (keep + 1 < numOutputs) {
    const unsigned first = keep + 1;
    const unsigned count = numOutputs - first;
    clonedAdjointGenericOp.getOutputsMutable().erase(first, count);
    body.eraseArguments(numInputs + first, count);
    indexingMaps.erase(indexingMaps.begin() + numInputs + first,
                       indexingMaps.begin() + numInputs + first + count);
  }

  // Drop the outputs before `keep`.
  if (keep > 0) {
    clonedAdjointGenericOp.getOutputsMutable().erase(0, keep);
    body.eraseArguments(numInputs, keep);
    indexingMaps.erase(indexingMaps.begin() + numInputs,
                       indexingMaps.begin() + numInputs + keep);
  }
  clonedAdjointGenericOp.setIndexingMapsAttr(
      builder.getArrayAttr(indexingMaps));

  Operation *clonedAddToOp = getAddToOp(clonedAdjointGenericOp);
  assert(clonedAddToOp && "cloned generic has no enzyme::AddToOp");

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointAfter(clonedAddToOp);
  auto terminator = builder.create<linalg::YieldOp>(loc);

  auto operand = clonedAddToOp->getOperand(keep);
  // Exactly one output is left, so its block argument sits immediately after
  // the input arguments (index `numInputs`, not `numInputs + i`).
  auto outputOperand = body.getArgument(numInputs);
  auto operandType = cast<AutoDiffTypeInterface>(operand.getType());

  // Build the accumulation right before the new terminator and yield it.
  builder.setInsertionPoint(terminator);
  auto returnValue =
      operandType.createAddOp(builder, loc, outputOperand, operand);

  terminator->setOperands({returnValue});
  clonedAddToOp->erase();
}

/// Converts every enzyme::GenericAdjointOp to a linalg.generic. When that
/// generic contains an enzyme::AddToOp, it is replaced by one linalg.generic
/// per AddToOp operand (built by processGenericDuplication), and the combined
/// generic is erased.
///
/// Generics that take memref<CacheType> inputs cannot be split yet; for those
/// the pass emits an error diagnostic and fails instead of aborting the
/// process.
struct AddToOpToSplitPass
    : public enzyme::AddToOpToSplitPassBase<AddToOpToSplitPass> {
  void runOnOperation() override {
    getOperation()->walk([&](enzyme::GenericAdjointOp enzymeAdjoint) {
      Location loc = enzymeAdjoint.getLoc();
      OpBuilder builder(enzymeAdjoint);
      auto adjoint = Utils::adjointToGeneric(enzymeAdjoint, builder, loc);

      Operation *addToOp = getAddToOp(adjoint);
      if (!addToOp)
        return WalkResult::advance();

      // TODO: duplicate memref<CacheType> inputs instead of failing.
      if (llvm::any_of(adjoint.getInputs(), [](Value input) {
            return isMemrefCacheType(input.getType());
          })) {
        mlir::emitError(loc,
                        "Cannot split AddToOp with memref<CacheType> inputs");
        signalPassFailure();
        return WalkResult::interrupt();
      }

      // The per-output clones are inserted right before the combined generic,
      // in operand order.
      builder.setInsertionPoint(adjoint);
      for (unsigned i = 0, e = addToOp->getNumOperands(); i < e; ++i)
        processGenericDuplication(adjoint.getOperation(), builder, loc,
                                  static_cast<int>(i));
      adjoint->erase();
      return WalkResult::advance();
    });
  }
};
} // end anonymous namespace

namespace mlir::enzyme {
/// Creates a new instance of the pass that splits enzyme::AddToOp-terminated
/// generic adjoints into one linalg.generic per output.
std::unique_ptr<Pass> createAddToOpToSplitPass() {
  std::unique_ptr<Pass> pass = std::make_unique<AddToOpToSplitPass>();
  return pass;
}
} // namespace mlir::enzyme
