// ShortestPathTest.cpp
//
//     Created on: 21. 01. 2018
//         Author: Jan Uhlik
//    Modified by:
//
// Copyright (c) 2017 Czech Technical University in Prague | Faculty of Information Technology. All rights reserved.
// Git repository: https://gitlab.fit.cvut.cz/algorithms-library-toolkit/automata-library

#include "ShortestPathTest.hpp"

#include <graph/GraphClasses.hpp>
#include <node/NodeClasses.hpp>
#include <edge/EdgeClasses.hpp>

#include <heuristic/SquareGridHeuristics.hpp>

#include <traverse/BFS.hpp>
#include <traverse/IDDFS.hpp>
#include <shortest_path/BellmanFord.hpp>
#include <shortest_path/SPFA.hpp>
#include "shortest_path/Dijkstra.hpp"
#include <shortest_path/AStar.hpp>
#include <shortest_path/IDAStar.hpp>
#include <shortest_path/MM.hpp>
#include <shortest_path/JPS.hpp>
#include <generate/RandomGraphFactory.hpp>
#include <generate/RandomGridFactory.hpp>

// Makes the names declared in namespace `graph` usable without qualification in this file.
using namespace graph;

// Register the suite twice: once in the registry named "shortest_path" (so it can be
// selected by that name) and once in the default, unnamed registry.
CPPUNIT_TEST_SUITE_NAMED_REGISTRATION(ShortestPathTest, "shortest_path");
CPPUNIT_TEST_SUITE_REGISTRATION(ShortestPathTest);

// Absolute tolerance used by the tests below when comparing floating-point path costs
// (`fabs(a - b) < EPS`). Note that the literal 10E-7 is 10 * 10^-7, i.e. 1e-6.
const double EPS = 10E-7; // Magic constant for double compare

// ---------------------------------------------------------------------------------------------------------------------

void ShortestPathTest::setUp() {
  TestFixture::setUp();
}

// ---------------------------------------------------------------------------------------------------------------------

void ShortestPathTest::tearDown() {
  TestFixture::tearDown();
}

// ---------------------------------------------------------------------------------------------------------------------

void ShortestPathTest::testBFSGraph() {
  graph::DirectedGraph<int, edge::WeightedEdge<int>> graph;

  graph.addNode(1);
  graph.addNode(2);
  graph.addNode(3);
  graph.addNode(4);
  graph.addNode(5);

  graph.addEdge(edge::WeightedEdge<int>(1, 2, 5));
  graph.addEdge(edge::WeightedEdge<int>(2, 3, 3));
  graph.addEdge(edge::WeightedEdge<int>(3, 4, 3));
  graph.addEdge(edge::WeightedEdge<int>(4, 5, 5));
  graph.addEdge(edge::WeightedEdge<int>(2, 4, 4));

  auto res1 = traverse::BFS::findPath(graph, 1, 5);

  auto res2 = traverse::BFS::findPathBidirectional(graph, 1, 5);

  ext::vector<int> res = {1, 2, 4, 5};
  CPPUNIT_ASSERT(res1 == res);
  CPPUNIT_ASSERT(res1.size() == res.size());
  CPPUNIT_ASSERT(res2 == res);
  CPPUNIT_ASSERT(res2.size() == res.size());
}

// ---------------------------------------------------------------------------------------------------------------------

void ShortestPathTest::testIDDFSGraph() {
  graph::DirectedGraph<int, ext::pair<int, int>> graph;

  graph.addNode(1);
  graph.addNode(2);
  graph.addNode(3);
  graph.addNode(4);
  graph.addNode(5);
  graph.addNode(5);
  graph.addNode(6);
  graph.addNode(7);
  graph.addNode(8);

  graph.addEdge(1, 2);
  graph.addEdge(1, 3);
  graph.addEdge(1, 4);
  graph.addEdge(2, 5);
  graph.addEdge(2, 6);
  graph.addEdge(3, 7);
  graph.addEdge(6, 8);

  auto bfs = traverse::BFS::findPath(graph, 1, 8);
  auto iddfs = traverse::IDDFS::findPath(graph, 1, 8);
  auto iddfsBi = traverse::IDDFS::findPathBidirectional(graph, 1, 8);

  CPPUNIT_ASSERT(bfs == iddfs);
  CPPUNIT_ASSERT(iddfsBi == iddfs);
}

// ---------------------------------------------------------------------------------------------------------------------

void ShortestPathTest::testDijkstraGraph() {
  graph::UndirectedGraph<int, edge::WeightedEdge<int>> graph;

  graph.addNode(1);
  graph.addNode(2);
  graph.addNode(3);
  graph.addNode(4);
  graph.addNode(5);

  graph.addEdge(edge::WeightedEdge<int>(1, 2, 5));
  graph.addEdge(edge::WeightedEdge<int>(2, 3, 3));
  graph.addEdge(edge::WeightedEdge<int>(3, 4, 3));
  graph.addEdge(edge::WeightedEdge<int>(4, 5, 5));
  graph.addEdge(edge::WeightedEdge<int>(2, 4, 4));

  auto res1 = graph::shortest_path::Dijkstra::findPath(graph, 1, 5);

  auto res2 = graph::shortest_path::Dijkstra::findPathBidirectional(graph, 1, 5);

  ext::vector<int> res = {1, 2, 4, 5};
  CPPUNIT_ASSERT(res1.first == res);
  CPPUNIT_ASSERT(res1.second == 14);
  CPPUNIT_ASSERT(res2.first == res);
  CPPUNIT_ASSERT(res2.second == 14);
}

// ---------------------------------------------------------------------------------------------------------------------

void ShortestPathTest::testBFSGrid() {
  grid::SquareGrid4<> graph(10, 10);
  using node_type = decltype(graph)::node_type;

  graph.addObstacle(3, 2);
  graph.addObstacle(3, 3);
  graph.addObstacle(3, 4);
  graph.addObstacle(3, 5);
  graph.addObstacle(4, 5);
  graph.addObstacle(5, 5);
  graph.addObstacle(6, 5);

  node_type start = ext::make_pair(8l, 2l);
  node_type goal = ext::make_pair(1l, 9l);

  auto res1 = traverse::BFS::findPath(graph, start, goal);
  auto res2 = traverse::BFS::findPathBidirectional(graph, start, goal);

  CPPUNIT_ASSERT(res1.size() == 15);
  CPPUNIT_ASSERT(res1.size() == res2.size());
}

// ---------------------------------------------------------------------------------------------------------------------

// Dijkstra on a 10x10 WeightedSquareGrid8 with an L-shaped wall.
//
// The cost of the path from (8, 2) to (1, 9) is expected to be 5 * sqrt(2) + 4, and the
// bidirectional variant must report the same cost. Both checks use the EPS tolerance:
// the two searches may accumulate the same edge weights in a different order, so their
// floating-point sums are not guaranteed to be bit-identical.
void ShortestPathTest::testDijkstraGrid() {
  grid::WeightedSquareGrid8<> graph(10, 10);
  using node_type = decltype(graph)::node_type;

  graph.addObstacle(3, 2);
  graph.addObstacle(3, 3);
  graph.addObstacle(3, 4);
  graph.addObstacle(3, 5);
  graph.addObstacle(4, 5);
  graph.addObstacle(5, 5);
  graph.addObstacle(6, 5);

  node_type start = ext::make_pair(8l, 2l);
  node_type goal = ext::make_pair(1l, 9l);

  auto res1 = graph::shortest_path::Dijkstra::findPath(graph, start, goal);
  auto res2 = graph::shortest_path::Dijkstra::findPathBidirectional(graph, start, goal);

  CPPUNIT_ASSERT(fabs(res1.second - (M_SQRT2 * 5 + 4)) < EPS);
  // Tolerance comparison instead of `==`, consistent with the other grid tests.
  CPPUNIT_ASSERT(fabs(res1.second - res2.second) < EPS);
}

// ---------------------------------------------------------------------------------------------------------------------

// A*, bidirectional A* and MM on an 11x11 WeightedSquareGrid8 with a cross-shaped wall.
//
// All three searches run from (9, 1) to (1, 9) with the diagonal-distance heuristic; the
// expected path cost is 4 * sqrt(2) + 8. Costs are compared with the EPS tolerance
// because each algorithm may sum the same edge weights in a different order, which makes
// exact `==` on floating-point results fragile.
void ShortestPathTest::testAStarGrid() {
  grid::WeightedSquareGrid8<> graph(11, 11);
  using node_type = decltype(graph)::node_type;

  // One arm of the cross: (2, 5) .. (7, 5).
  graph.addObstacle(2, 5);
  graph.addObstacle(3, 5);
  graph.addObstacle(4, 5);
  graph.addObstacle(5, 5);
  graph.addObstacle(6, 5);
  graph.addObstacle(7, 5);
  // Other arm of the cross: (5, 3) .. (5, 8); (5, 5) is already set above.
  graph.addObstacle(5, 3);
  graph.addObstacle(5, 4);
  graph.addObstacle(5, 6);
  graph.addObstacle(5, 7);
  graph.addObstacle(5, 8);

  node_type start = ext::make_pair(9l, 1l);
  node_type goal = ext::make_pair(1l, 9l);

  // Heuristic for the forward search: estimated distance to the goal.
  auto f_heuristic_forward = [&](const node_type &n) -> double {
    return heuristic::DiagonalDistance::diagonalDistance(goal, n);
  };

  // Heuristic for the backward search: estimated distance to the start.
  auto f_heuristic_backward = [&](const node_type &n) -> double {
    return heuristic::DiagonalDistance::diagonalDistance(start, n);
  };

  auto res1 = graph::shortest_path::AStar::findPath(graph,
                                             start,
                                             goal,
                                             f_heuristic_forward);

  auto res2 = graph::shortest_path::AStar::findPathBidirectional(graph,
                                                          start,
                                                          goal,
                                                          f_heuristic_forward,
                                                          f_heuristic_backward);

  auto res3 = graph::shortest_path::MM::findPathBidirectional(graph,
                                                       start,
                                                       goal,
                                                       f_heuristic_forward,
                                                       f_heuristic_backward);

  CPPUNIT_ASSERT(fabs(res1.second - (M_SQRT2 * 4 + 8)) < EPS);
  // Tolerance comparisons instead of `==`, consistent with the other grid tests.
  CPPUNIT_ASSERT(fabs(res1.second - res2.second) < EPS);
  CPPUNIT_ASSERT(fabs(res2.second - res3.second) < EPS);
}

// ---------------------------------------------------------------------------------------------------------------------

void ShortestPathTest::testAllNonGridAlgorithm() {
  graph::UndirectedGraph<int, edge::WeightedEdge<int, int>> graph;
  int start = 5;
  int goal = 17;

  auto f_heuristic = [&](const int &) -> int {
    return 0;
  };

  graph.addNode(1);
  graph.addNode(2);
  graph.addNode(3);
  graph.addNode(4);
  graph.addNode(5);
  graph.addNode(6);
  graph.addNode(7);
  graph.addNode(8);
  graph.addNode(9);
  graph.addNode(10);
  graph.addNode(11);
  graph.addNode(12);
  graph.addNode(13);
  graph.addNode(14);
  graph.addNode(15);
  graph.addNode(16);
  graph.addNode(17);
  graph.addNode(18);
  graph.addNode(19);
  graph.addNode(20);

  graph.addEdge(1, 5, 10);
  graph.addEdge(2, 3, 20);
  graph.addEdge(2, 4, 4);
  graph.addEdge(2, 6, 19);
  graph.addEdge(2, 5, 9);
  graph.addEdge(4, 7, 21);
  graph.addEdge(5, 6, 24);
  graph.addEdge(6, 7, 1);
  graph.addEdge(6, 8, 3);
  graph.addEdge(6, 9, 15);
  graph.addEdge(7, 10, 18);
  graph.addEdge(8, 12, 22);
  graph.addEdge(9, 11, 14);
  graph.addEdge(10, 18, 50); //
  graph.addEdge(10, 11, 8);
  graph.addEdge(11, 13, 2);
  graph.addEdge(12, 13, 5);
  graph.addEdge(13, 14, 13);
  graph.addEdge(13, 19, 7);
  graph.addEdge(14, 15, 16);
  graph.addEdge(15, 16, 6);
  graph.addEdge(15, 17, 25);
  graph.addEdge(16, 18, 17);
  graph.addEdge(17, 18, 12);
  graph.addEdge(17, 20, 11);
  graph.addEdge(19, 20, 23);

  auto dijkstra = graph::shortest_path::Dijkstra::findPath(graph, start, goal);
  auto dijkstraBi = graph::shortest_path::Dijkstra::findPathBidirectional(graph, start, goal);
  auto bellmanFord = graph::shortest_path::BellmanFord::findPath(graph, start, goal);
  auto spfa = graph::shortest_path::SPFA::findPath(graph, start, goal);
  auto astar = graph::shortest_path::AStar::findPath(graph, start, goal, f_heuristic);
  auto idastar = graph::shortest_path::IDAStar::findPath(graph, start, goal, f_heuristic);
  auto astarBi = graph::shortest_path::AStar::findPathBidirectional(graph, start, goal, f_heuristic, f_heuristic);
  auto mm = graph::shortest_path::MM::findPathBidirectional(graph, start, goal, f_heuristic, f_heuristic);

  CPPUNIT_ASSERT(dijkstra.second == 94);

  CPPUNIT_ASSERT(dijkstra == dijkstraBi);
  CPPUNIT_ASSERT(bellmanFord == dijkstraBi);
  CPPUNIT_ASSERT(bellmanFord == spfa);
  CPPUNIT_ASSERT(astar == spfa);
  CPPUNIT_ASSERT(astar == idastar);
  CPPUNIT_ASSERT(astarBi == idastar);
  CPPUNIT_ASSERT(astarBi == mm);

  auto BFS = traverse::BFS::findPath(graph, start, goal);
  auto BFSBi = traverse::BFS::findPathBidirectional(graph, start, goal);

  CPPUNIT_ASSERT(BFS == BFSBi);
}

// ---------------------------------------------------------------------------------------------------------------------

void ShortestPathTest::testAllGridAlgorithm() {
  grid::WeightedSquareGrid8<int, edge::WeightedEdge<ext::pair<int, int>, double>> graph(11, 11);
  using node_type = ext::pair<int, int>;

  auto start = ext::make_pair(8, 2);
  auto goal = ext::make_pair(1, 9);

  graph.addObstacle(2, 5);
  graph.addObstacle(3, 5);
  graph.addObstacle(4, 5);
  graph.addObstacle(5, 5);
  graph.addObstacle(6, 5);
  graph.addObstacle(7, 5);
  graph.addObstacle(5, 3);
  graph.addObstacle(5, 4);
  graph.addObstacle(5, 6);
  graph.addObstacle(5, 7);
  graph.addObstacle(5, 8);

  auto f_heuristic_forward = [&](const node_type &n) -> double {
    return heuristic::DiagonalDistance::diagonalDistance(goal, n);
  };

  auto f_heuristic_backward = [&](const node_type &n) -> double {
    return heuristic::DiagonalDistance::diagonalDistance(start, n);
  };

  auto dijkstra = graph::shortest_path::Dijkstra::findPath(graph, start, goal);
  auto dijkstraBi = graph::shortest_path::Dijkstra::findPathBidirectional(graph, start, goal);
  auto bellmanFord = graph::shortest_path::BellmanFord::findPath(graph, start, goal);
  auto spfa = graph::shortest_path::SPFA::findPath(graph, start, goal);
  auto astar = graph::shortest_path::AStar::findPath(graph, start, goal, f_heuristic_forward);
  auto idastar = graph::shortest_path::IDAStar::findPath(graph, start, goal, f_heuristic_forward);
  auto astarBi =
      graph::shortest_path::AStar::findPathBidirectional(graph, start, goal, f_heuristic_forward, f_heuristic_backward);
  auto mm = graph::shortest_path::MM::findPathBidirectional(graph, start, goal, f_heuristic_forward, f_heuristic_backward);
  auto jps = graph::shortest_path::JPS::findPath(graph, start, goal, f_heuristic_forward);

  CPPUNIT_ASSERT(fabs(dijkstra.second - (M_SQRT2 * 3 + 8)) < EPS);
  CPPUNIT_ASSERT(fabs(dijkstra.second - dijkstraBi.second) < EPS);
  CPPUNIT_ASSERT(fabs(bellmanFord.second - dijkstraBi.second) < EPS);
  CPPUNIT_ASSERT(fabs(bellmanFord.second - spfa.second) < EPS);
  CPPUNIT_ASSERT(fabs(astar.second - spfa.second) < EPS);
  CPPUNIT_ASSERT(fabs(astar.second - idastar.second) < EPS);
  CPPUNIT_ASSERT(fabs(astarBi.second - idastar.second) < EPS);
  CPPUNIT_ASSERT(fabs(astarBi.second - mm.second) < EPS);
  CPPUNIT_ASSERT(fabs(jps.second - mm.second) < EPS);
}

// ---------------------------------------------------------------------------------------------------------------------

void ShortestPathTest::testAllGridAlgorithmRandom() {
  auto graph_tuple = generate::RandomGridFactory::
  randomGrid<grid::WeightedSquareGrid8<long, edge::WeightedEdge<ext::pair<long, long>>>>(50, 50, 100);
  auto graph = std::get<0>(graph_tuple);
  auto start = std::get<1>(graph_tuple);
  auto goal = std::get<2>(graph_tuple);

  using weight_type = decltype(graph)::edge_type::weight_type;
  using node_type = decltype(graph)::node_type;

  auto f_heuristic_forward = [&](const node_type &n) -> weight_type {
    return heuristic::DiagonalDistance::diagonalDistance(goal, n);
  };

  auto f_heuristic_backward = [&](const node_type &n) -> weight_type {
    return heuristic::DiagonalDistance::diagonalDistance(start, n);
  };

  auto dijkstra = graph::shortest_path::Dijkstra::findPath(graph, start, goal);
  auto dijkstraBi = graph::shortest_path::Dijkstra::findPathBidirectional(graph, start, goal);
  auto bellmanFord = graph::shortest_path::BellmanFord::findPath(graph, start, goal);
  auto spfa = graph::shortest_path::SPFA::findPath(graph, start, goal);
  auto astar = graph::shortest_path::AStar::findPath(graph, start, goal, f_heuristic_forward);
  auto astarBi =
      graph::shortest_path::AStar::findPathBidirectional(graph, start, goal, f_heuristic_forward, f_heuristic_backward);
  auto mm = graph::shortest_path::MM::findPathBidirectional(graph, start, goal, f_heuristic_forward, f_heuristic_backward);
  auto jps = graph::shortest_path::JPS::findPath(graph, start, goal, f_heuristic_forward);

  CPPUNIT_ASSERT(fabs(dijkstra.second - dijkstraBi.second) < EPS);
  CPPUNIT_ASSERT(fabs(bellmanFord.second - dijkstraBi.second) < EPS);
  CPPUNIT_ASSERT(fabs(bellmanFord.second - spfa.second) < EPS);
  CPPUNIT_ASSERT(fabs(astar.second - spfa.second) < EPS);
  CPPUNIT_ASSERT(fabs(astarBi.second - astar.second) < EPS);
  CPPUNIT_ASSERT(fabs(astarBi.second - mm.second) < EPS);
  CPPUNIT_ASSERT(fabs(jps.second - mm.second) < EPS);
}