Skip to content
Snippets Groups Projects
Commit 8ba19d2d authored by David Rosca's avatar David Rosca
Browse files

Graph algo: Add DFS + tests

parent a684305f
No related branches found
No related tags found
No related merge requests found
#include "Dfs.h"
#include <stack>
#include <unordered_map>
namespace graph
{
namespace traverse
{
struct Data
{
Node start;
std::function<bool(const Node&)> func;
};
template <typename T>
static void dfs_impl(const T &graph, const Node &start, std::function<bool(const Node&)> func)
{
std::unordered_map<Node, bool> visited;
std::stack<Node> s;
s.push(start);
while (!s.empty()) {
Node n = s.top(); s.pop();
if (visited.find(n) == visited.end()) {
if (!func(n)) {
return;
}
visited.insert({n, true});
for (const Node &e : graph.neighbors(n)) {
s.push(e);
}
}
}
}
void Dfs::dfs(const Graph &graph, const Node &start, std::function<bool(const Node&)> func)
{
Data data = { start, func };
graph.getData().Accept(static_cast<void*>(&data), DFS);
}
void Dfs::dfs(const DirectedGraph &graph, const Node &start, std::function<bool(const Node&)> func)
{
dfs_impl(graph, start, func);
}
void Dfs::dfs(const UndirectedGraph &graph, const Node &start, std::function<bool(const Node&)> func)
{
dfs_impl(graph, start, func);
}
void Dfs::Visit(void *data, const DirectedGraph &graph) const
{
Data d = *static_cast<Data*>(data);
dfs(graph, d.start, d.func);
}
void Dfs::Visit(void *data, const UndirectedGraph &graph) const
{
Data d = *static_cast<Data*>(data);
dfs(graph, d.start, d.func);
}
const Dfs Dfs::DFS;
} // namespace traverse
} // namespace graph
#ifndef GRAPH_DFS_H_
#define GRAPH_DFS_H_
#include <graph/Graph.h>
#include <graph/directed/DirectedGraph.h>
#include <graph/undirected/UndirectedGraph.h>
namespace graph
{
namespace traverse
{
// func is called for each visited node, traversal stops when returning false
class Dfs : public graph::VisitableGraphBase::const_visitor_type
{
public:
static void dfs(const Graph &graph, const Node &start, std::function<bool(const Node&)> func);
static void dfs(const DirectedGraph &graph, const Node &start, std::function<bool(const Node&)> func);
static void dfs(const UndirectedGraph &graph, const Node &start, std::function<bool(const Node&)> func);
private:
void Visit(void *data, const DirectedGraph &graph) const;
void Visit(void *data, const UndirectedGraph &graph) const;
static const Dfs DFS;
};
} // namespace traverse
} // namespace graph
#endif // GRAPH_DFS_H_
#include "DfsTest.h"
#include "graph/traverse/Dfs.h"
CPPUNIT_TEST_SUITE_REGISTRATION(GraphDfsTest);
void GraphDfsTest::testTraverseAll()
{
// Common
int counter;
graph::Node n1("n1");
graph::Node n2("n2");
graph::Node n3("n3");
graph::Node n4("n4");
graph::Node n5("n5");
graph::Node n6("n6");
// Directed
graph::DirectedGraph dg;
dg.addEdge(graph::DirectedEdge(n1, n2));
dg.addEdge(graph::DirectedEdge(n1, n2, "multi-edge"));
dg.addEdge(graph::DirectedEdge(n2, n3));
dg.addEdge(graph::DirectedEdge(n3, n4));
dg.addEdge(graph::DirectedEdge(n3, n6));
dg.addEdge(graph::DirectedEdge(n4, n1));
dg.addEdge(graph::DirectedEdge(n4, n5));
counter = 0;
graph::traverse::Dfs::dfs(dg, n1, [&counter](const graph::Node &) {
counter++;
return true;
});
CPPUNIT_ASSERT_EQUAL(6, counter);
// Undirected
graph::UndirectedGraph ug;
ug.addEdge(graph::UndirectedEdge(n1, n2));
ug.addEdge(graph::UndirectedEdge(n1, n2, "multi-edge"));
ug.addEdge(graph::UndirectedEdge(n2, n3));
ug.addEdge(graph::UndirectedEdge(n3, n4));
ug.addEdge(graph::UndirectedEdge(n3, n6));
ug.addEdge(graph::UndirectedEdge(n4, n1));
ug.addEdge(graph::UndirectedEdge(n4, n5));
counter = 0;
graph::traverse::Dfs::dfs(ug, n1, [&counter](const graph::Node &) {
counter++;
return true;
});
CPPUNIT_ASSERT_EQUAL(6, counter);
}
void GraphDfsTest::testEarlyReturn()
{
// Common
int counter;
graph::Node n1("n1");
graph::Node n2("n2");
graph::Node n3("n3");
graph::Node n4("n4");
// Directed
graph::DirectedGraph dg;
dg.addEdge(graph::DirectedEdge(n1, n2));
dg.addEdge(graph::DirectedEdge(n1, n2, "multi-edge"));
dg.addEdge(graph::DirectedEdge(n2, n3));
dg.addEdge(graph::DirectedEdge(n3, n4));
counter = 0;
graph::traverse::Dfs::dfs(dg, n1, [&counter](const graph::Node &) {
counter++;
if (counter == 2) {
return false;
}
return true;
});
CPPUNIT_ASSERT_EQUAL(2, counter);
// Undirected
graph::UndirectedGraph ug;
ug.addEdge(graph::UndirectedEdge(n1, n2));
ug.addEdge(graph::UndirectedEdge(n1, n2, "multi-edge"));
ug.addEdge(graph::UndirectedEdge(n2, n3));
ug.addEdge(graph::UndirectedEdge(n3, n4));
counter = 0;
graph::traverse::Dfs::dfs(ug, n1, [&counter](const graph::Node &) {
counter++;
if (counter == 2) {
return false;
}
return true;
});
CPPUNIT_ASSERT_EQUAL(2, counter);
}
void GraphDfsTest::testDisconnectedGraph()
{
// Common
int counter;
graph::Node n1("n1");
graph::Node n2("n2");
graph::Node n3("n3");
graph::Node n4("n4");
graph::Node n5("n5");
// Directed
graph::DirectedGraph dg;
dg.addNode(n5);
dg.addEdge(graph::DirectedEdge(n1, n2));
dg.addEdge(graph::DirectedEdge(n1, n2, "multi-edge"));
dg.addEdge(graph::DirectedEdge(n2, n3));
dg.addEdge(graph::DirectedEdge(n3, n4));
counter = 0;
graph::traverse::Dfs::dfs(dg, n1, [&counter](const graph::Node &) {
counter++;
return true;
});
CPPUNIT_ASSERT_EQUAL(4, counter);
counter = 0;
graph::traverse::Dfs::dfs(dg, n5, [&counter](const graph::Node &) {
counter++;
return true;
});
CPPUNIT_ASSERT_EQUAL(1, counter);
// Undirected
graph::UndirectedGraph ug;
ug.addNode(n5);
ug.addEdge(graph::UndirectedEdge(n1, n2));
ug.addEdge(graph::UndirectedEdge(n1, n2, "multi-edge"));
ug.addEdge(graph::UndirectedEdge(n2, n3));
ug.addEdge(graph::UndirectedEdge(n3, n4));
counter = 0;
graph::traverse::Dfs::dfs(ug, n1, [&counter](const graph::Node &) {
counter++;
return true;
});
CPPUNIT_ASSERT_EQUAL(4, counter);
counter = 0;
graph::traverse::Dfs::dfs(ug, n5, [&counter](const graph::Node &) {
counter++;
return true;
});
CPPUNIT_ASSERT_EQUAL(1, counter);
}
#ifndef DFS_TEST_H_
#define DFS_TEST_H_
#include <cppunit/extensions/HelperMacros.h>
class GraphDfsTest : public CppUnit::TestFixture
{
CPPUNIT_TEST_SUITE(GraphDfsTest);
CPPUNIT_TEST(testTraverseAll);
CPPUNIT_TEST(testEarlyReturn);
CPPUNIT_TEST(testDisconnectedGraph);
CPPUNIT_TEST_SUITE_END();
public:
void testTraverseAll();
void testEarlyReturn();
void testDisconnectedGraph();
};
#endif // DFS_TEST_H_
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment