#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/LoopUnrollAnalyzer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using namespace llvm;

namespace {
struct OurLoopFission : public LoopPass {
  std::vector<BasicBlock *> LoopBasicBlocks;

  static char ID; // Pass identification, replacement for typeid
  OurLoopFission() : LoopPass(ID) {}

  BasicBlock *copyLoop(Loop *L)
  {
    Instruction *Clone;
    std::unordered_map<Value *, Value *> Mapping;
    std::unordered_map<BasicBlock *, BasicBlock *> BasicBlockMapping;
    std::vector<BasicBlock *> LoopBasicBlocksCopy;
    BasicBlock *NewBasicBlock;
    BasicBlock *ExitBlock = L->getExitBlock();
    IRBuilder<> Builder(ExitBlock);

    for (BasicBlock *BB : LoopBasicBlocks) {
      NewBasicBlock = BasicBlock::Create(ExitBlock->getContext(), "",
                                         ExitBlock->getParent(), ExitBlock);
      BasicBlockMapping[BB] = NewBasicBlock;
      LoopBasicBlocksCopy.push_back(NewBasicBlock);
    }

    for (BasicBlock *BB : LoopBasicBlocks) {
      BasicBlock *CopyBlock = BasicBlockMapping[BB];
      Builder.SetInsertPoint(CopyBlock);

      for (Instruction &I : *BB) {
        Clone = I.clone();
        Mapping[&I] = Clone;
        Builder.Insert(Clone);

        for (size_t i = 0; i < Clone->getNumOperands(); i++) {
          if (Mapping.find(Clone->getOperand(i)) != Mapping.end()) {
            Clone->setOperand(i, Mapping[Clone->getOperand(i)]);
          }
        }
      }
    }

    for (BasicBlock *BB : LoopBasicBlocksCopy) {
      for (size_t i = 0; i < BB->getTerminator()->getNumSuccessors(); i++) {
        BasicBlock *Successor = BB->getTerminator()->getSuccessor(i);
        if (BasicBlockMapping.find(Successor) != BasicBlockMapping.end()) {
          BB->getTerminator()->setSuccessor(i, BasicBlockMapping[Successor]);
        }
      }
    }

    BasicBlock *FirstIfBasicBlock = findIfBasicBlock(true, LoopBasicBlocksCopy);
    BasicBlock *LastIfBasicBlock = findIfBasicBlock(false, LoopBasicBlocksCopy);

    FirstIfBasicBlock->printAsOperand(errs());
    errs() << "\n";
    LastIfBasicBlock->printAsOperand(errs());

    std::unordered_set<BasicBlock *> BlocksToDelete;
    deleteAllFromBlock(FirstIfBasicBlock, LastIfBasicBlock, BlocksToDelete);
    LoopBasicBlocksCopy.front()->getTerminator()->setSuccessor(0, LastIfBasicBlock);

    for (BasicBlock *BlockToDelete : BlocksToDelete) {
      BlockToDelete->eraseFromParent();
    }

    return LoopBasicBlocksCopy.front();
  }

  BasicBlock *findIfBasicBlock(bool findFirst, std::vector<BasicBlock *> LoopBasicBlocks)
  {
    BasicBlock *IfBasicBlock = nullptr;

    for (size_t i = 1; i < LoopBasicBlocks.size(); i++) {
      if (BranchInst *Branch = dyn_cast<BranchInst>(LoopBasicBlocks[i]->getTerminator())) {
        if (Branch->isConditional()) {
          if (findFirst) {
            return LoopBasicBlocks[i];
          }
          else {
            IfBasicBlock = LoopBasicBlocks[i];
          }
        }
      }
    }

    return IfBasicBlock;
  }

  void deleteAllFromBlock(BasicBlock *DeleteFrom, BasicBlock *StopBlock,
                          std::unordered_set<BasicBlock *> &BlocksToDelete)
  {
    BlocksToDelete.insert(DeleteFrom);

    for (size_t i = 0; i < DeleteFrom->getTerminator()->getNumSuccessors(); i++) {
      BasicBlock *Successor = DeleteFrom->getTerminator()->getSuccessor(i);
      if (Successor != StopBlock || BlocksToDelete.find(Successor) != BlocksToDelete.end()) {
        deleteAllFromBlock(Successor, StopBlock, BlocksToDelete);
      }
    }
  }

  void fissionLoop(Loop *L)
  {
    BasicBlock *NewLoop = copyLoop(L);
    LoopBasicBlocks.front()->getTerminator()->setSuccessor(1, NewLoop);

    BasicBlock *IfBasicBlock = findIfBasicBlock(true, LoopBasicBlocks);
    BranchInst *Branch = dyn_cast<BranchInst>(IfBasicBlock->getTerminator()->getSuccessor(1)->getTerminator());
    bool isConditional = Branch->isConditional();
    std::unordered_set<BasicBlock *> BlocksToDelete;
    // Ima else granu
    if (!isConditional) {
      deleteAllFromBlock(Branch->getSuccessor(0),
                         L->getLoopLatch(), BlocksToDelete);
      IfBasicBlock->getTerminator()->getSuccessor(0)->getTerminator()->setSuccessor(0, L->getLoopLatch());
      IfBasicBlock->getTerminator()->getSuccessor(1)->getTerminator()->setSuccessor(0, L->getLoopLatch());
    }
    // Nema else granu
    else {
      deleteAllFromBlock(IfBasicBlock->getTerminator()->getSuccessor(1), L->getLoopLatch(), BlocksToDelete);
      IfBasicBlock->getTerminator()->getSuccessor(0)->getTerminator()->setSuccessor(0, L->getLoopLatch());
      IfBasicBlock->getTerminator()->setSuccessor(1, L->getLoopLatch());
    }

    for (BasicBlock *BlockToDelete : BlocksToDelete) {
      BlockToDelete->eraseFromParent();
    }
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
    LoopBasicBlocks = L->getBlocksVector();
    fissionLoop(L);
    return true;
  }
}; // end of struct OurLoopFission
}  // end of anonymous namespace

// Storage for the pass identifier; the pass machinery uses the address of
// this variable, not its value.
char OurLoopFission::ID = 0;

// Registers the pass under the command-line name "loop-fission" with an empty
// description.
static RegisterPass<OurLoopFission>
    RegisteredLoopFission("loop-fission", /*Name=*/"",
                          /*CFGOnly=*/false,
                          /*is_analysis=*/false);