From 894d95fc11f579361cde033d86b9c2a1400ae37c Mon Sep 17 00:00:00 2001
From: David Rosca <roscadav@fit.cvut.cz>
Date: Sun, 8 Mar 2015 09:42:20 +0100
Subject: [PATCH] DFS: Add version returning predecessors and opening/closing
 times

Also simplify lambdas in tests to capture everything
---
 alib2algo/src/graph/traverse/Dfs.cpp          | 66 ++++++++++++--
 alib2algo/src/graph/traverse/Dfs.h            |  3 +
 alib2algo/test-src/graph/traverse/DfsTest.cpp | 88 +++++++++++++++++--
 alib2algo/test-src/graph/traverse/DfsTest.h   |  2 +
 4 files changed, 142 insertions(+), 17 deletions(-)

diff --git a/alib2algo/src/graph/traverse/Dfs.cpp b/alib2algo/src/graph/traverse/Dfs.cpp
index 30f00d3ce8..dda9f51b1f 100644
--- a/alib2algo/src/graph/traverse/Dfs.cpp
+++ b/alib2algo/src/graph/traverse/Dfs.cpp
@@ -13,6 +13,7 @@ struct Data
 {
 	Node start;
 	std::function<bool(const Node&)> func;
+	std::function<void(const Node&, const Node&, int, int)> func2;
 };
 
 template <typename T>
@@ -22,23 +23,48 @@ static void dfs_impl(const T &graph, const Node &start, std::function<bool(const
 	std::stack<Node> s;
 
 	s.push(start);
+	visited[start] = true;
+
 	while (!s.empty()) {
 		Node n = s.top(); s.pop();
-		if (visited.find(n) == visited.end()) {
-			if (!func(n)) {
-				return;
-			}
-			visited.insert({n, true});
-			for (const Node &e : graph.neighbors(n)) {
+
+		if (!func(n)) {
+			return;
+		}
+
+		for (const Node &e : graph.neighbors(n)) {
+			if (visited.find(e) == visited.end()) {
+				visited[e] = true;
 				s.push(e);
 			}
 		}
 	}
 }
 
+template <typename T>
+static void dfs2_impl(const T &graph, const Node &n, const Node &p, std::unordered_map<Node, bool> &visited, int &time, std::function<void(const Node&, const Node&, int, int)> func)
+{
+	int opened = ++time;
+	visited[n] = true;
+
+	for (const Node &e : graph.neighbors(n)) {
+		if (visited.find(e) == visited.end()) {
+			dfs2_impl(graph, e, n, visited, time, func);
+		}
+	}
+
+	func(n, p, opened, ++time);
+}
+
 void Dfs::dfs(const Graph &graph, const Node &start, std::function<bool(const Node&)> func)
 {
-	Data data = { start, func };
+	Data data = { start, func, nullptr };
+	graph.getData().Accept(static_cast<void*>(&data), DFS);
+}
+
+void Dfs::dfs(const Graph &graph, const Node &start, std::function<void(const Node&, const Node&, int, int)> func)
+{
+	Data data = { start, nullptr, func };
 	graph.getData().Accept(static_cast<void*>(&data), DFS);
 }
 
@@ -47,21 +73,43 @@ void Dfs::dfs(const DirectedGraph &graph, const Node &start, std::function<bool(
 	dfs_impl(graph, start, func);
 }
 
+void Dfs::dfs(const DirectedGraph &graph, const Node &start, std::function<void(const Node&, const Node&, int, int)> func)
+{
+	int time = 0;
+	std::unordered_map<Node, bool> visited;
+	dfs2_impl(graph, start, Node(), visited, time, func);
+}
+
 void Dfs::dfs(const UndirectedGraph &graph, const Node &start, std::function<bool(const Node&)> func)
 {
 	dfs_impl(graph, start, func);
 }
 
+void Dfs::dfs(const UndirectedGraph &graph, const Node &start, std::function<void(const Node&, const Node&, int, int)> func)
+{
+	int time = 0;
+	std::unordered_map<Node, bool> visited;
+	dfs2_impl(graph, start, Node(), visited, time, func);
+}
+
 void Dfs::Visit(void *data, const DirectedGraph &graph) const
 {
 	Data d = *static_cast<Data*>(data);
-	dfs(graph, d.start, d.func);
+	if (d.func) {
+		dfs(graph, d.start, d.func);
+	} else if (d.func2) {
+		dfs(graph, d.start, d.func2);
+	}
 }
 
 void Dfs::Visit(void *data, const UndirectedGraph &graph) const
 {
 	Data d = *static_cast<Data*>(data);
-	dfs(graph, d.start, d.func);
+	if (d.func) {
+		dfs(graph, d.start, d.func);
+	} else if (d.func2) {
+		dfs(graph, d.start, d.func2);
+	}
 }
 
 const Dfs Dfs::DFS;
diff --git a/alib2algo/src/graph/traverse/Dfs.h b/alib2algo/src/graph/traverse/Dfs.h
index 85fa390ba3..bc45c5de49 100644
--- a/alib2algo/src/graph/traverse/Dfs.h
+++ b/alib2algo/src/graph/traverse/Dfs.h
@@ -16,9 +16,12 @@ class Dfs : public graph::VisitableGraphBase::const_visitor_type
 {
 public:
 	static void dfs(const Graph &graph, const Node &start, std::function<bool(const Node&)> func);
+	static void dfs(const Graph &graph, const Node &start, std::function<void(const Node&, const Node&, int, int)> func);
 
 	static void dfs(const DirectedGraph &graph, const Node &start, std::function<bool(const Node&)> func);
+	static void dfs(const DirectedGraph &graph, const Node &start, std::function<void(const Node&, const Node&, int, int)> func);
 	static void dfs(const UndirectedGraph &graph, const Node &start, std::function<bool(const Node&)> func);
+	static void dfs(const UndirectedGraph &graph, const Node &start, std::function<void(const Node&, const Node&, int, int)> func);
 
 private:
 	void Visit(void *data, const DirectedGraph &graph) const;
diff --git a/alib2algo/test-src/graph/traverse/DfsTest.cpp b/alib2algo/test-src/graph/traverse/DfsTest.cpp
index f4c8071e83..32541be9b9 100644
--- a/alib2algo/test-src/graph/traverse/DfsTest.cpp
+++ b/alib2algo/test-src/graph/traverse/DfsTest.cpp
@@ -26,7 +26,7 @@ void GraphDfsTest::testTraverseAll()
 	dg.addEdge(graph::DirectedEdge(n4, n5));
 
 	counter = 0;
-	graph::traverse::Dfs::dfs(dg, n1, [&counter](const graph::Node &) {
+	graph::traverse::Dfs::dfs(dg, n1, [&](const graph::Node &) {
 		counter++;
 		return true;
 	});
@@ -44,7 +44,7 @@ void GraphDfsTest::testTraverseAll()
 	ug.addEdge(graph::UndirectedEdge(n4, n5));
 
 	counter = 0;
-	graph::traverse::Dfs::dfs(ug, n1, [&counter](const graph::Node &) {
+	graph::traverse::Dfs::dfs(ug, n1, [&](const graph::Node &) {
 		counter++;
 		return true;
 	});
@@ -69,7 +69,7 @@ void GraphDfsTest::testEarlyReturn()
 	dg.addEdge(graph::DirectedEdge(n3, n4));
 
 	counter = 0;
-	graph::traverse::Dfs::dfs(dg, n1, [&counter](const graph::Node &) {
+	graph::traverse::Dfs::dfs(dg, n1, [&](const graph::Node &) {
 		counter++;
 		if (counter == 2) {
 			return false;
@@ -87,7 +87,7 @@ void GraphDfsTest::testEarlyReturn()
 	ug.addEdge(graph::UndirectedEdge(n3, n4));
 
 	counter = 0;
-	graph::traverse::Dfs::dfs(ug, n1, [&counter](const graph::Node &) {
+	graph::traverse::Dfs::dfs(ug, n1, [&](const graph::Node &) {
 		counter++;
 		if (counter == 2) {
 			return false;
@@ -117,7 +117,7 @@ void GraphDfsTest::testDisconnectedGraph()
 	dg.addEdge(graph::DirectedEdge(n3, n4));
 
 	counter = 0;
-	graph::traverse::Dfs::dfs(dg, n1, [&counter](const graph::Node &) {
+	graph::traverse::Dfs::dfs(dg, n1, [&](const graph::Node &) {
 		counter++;
 		return true;
 	});
@@ -125,7 +125,7 @@ void GraphDfsTest::testDisconnectedGraph()
 	CPPUNIT_ASSERT_EQUAL(4, counter);
 
 	counter = 0;
-	graph::traverse::Dfs::dfs(dg, n5, [&counter](const graph::Node &) {
+	graph::traverse::Dfs::dfs(dg, n5, [&](const graph::Node &) {
 		counter++;
 		return true;
 	});
@@ -141,7 +141,7 @@ void GraphDfsTest::testDisconnectedGraph()
 	ug.addEdge(graph::UndirectedEdge(n3, n4));
 
 	counter = 0;
-	graph::traverse::Dfs::dfs(ug, n1, [&counter](const graph::Node &) {
+	graph::traverse::Dfs::dfs(ug, n1, [&](const graph::Node &) {
 		counter++;
 		return true;
 	});
@@ -149,10 +149,82 @@ void GraphDfsTest::testDisconnectedGraph()
 	CPPUNIT_ASSERT_EQUAL(4, counter);
 
 	counter = 0;
-	graph::traverse::Dfs::dfs(ug, n5, [&counter](const graph::Node &) {
+	graph::traverse::Dfs::dfs(ug, n5, [&](const graph::Node &) {
 		counter++;
 		return true;
 	});
 
 	CPPUNIT_ASSERT_EQUAL(1, counter);
 }
+
+void GraphDfsTest::testDfs2()
+{
+	// Common
+	int counter;
+	std::unordered_map<graph::Node, int> opened;
+	std::unordered_map<graph::Node, int> closed;
+	std::unordered_map<graph::Node, graph::Node> predecessors;
+
+	graph::Node n1("n1");
+	graph::Node n2("n2");
+	graph::Node n3("n3");
+	graph::Node n4("n4");
+	graph::Node n5("n5");
+
+	// Directed
+	graph::DirectedGraph dg;
+	dg.addEdge(graph::DirectedEdge(n1, n2));
+	dg.addEdge(graph::DirectedEdge(n1, n3));
+	dg.addEdge(graph::DirectedEdge(n1, n4));
+	dg.addEdge(graph::DirectedEdge(n2, n5));
+
+	counter = 0;
+	opened.clear();
+	closed.clear();
+	predecessors.clear();
+
+	graph::traverse::Dfs::dfs(dg, n1, [&](const graph::Node &n, const graph::Node &p, int o, int c) {
+		counter++;
+		opened[n] = o;
+		closed[n] = c;
+		predecessors[n] = p;
+	});
+
+	CPPUNIT_ASSERT_EQUAL(5, counter);
+	CPPUNIT_ASSERT_EQUAL(4, closed[n5]);
+
+	CPPUNIT_ASSERT_EQUAL(graph::Node(), predecessors[n1]);
+	CPPUNIT_ASSERT_EQUAL(n1, predecessors[n2]);
+	CPPUNIT_ASSERT_EQUAL(n1, predecessors[n3]);
+	CPPUNIT_ASSERT_EQUAL(n1, predecessors[n4]);
+	CPPUNIT_ASSERT_EQUAL(n2, predecessors[n5]);
+
+	// Undirected
+	graph::UndirectedGraph ug;
+	ug.addEdge(graph::UndirectedEdge(n1, n2));
+	ug.addEdge(graph::UndirectedEdge(n1, n3));
+	ug.addEdge(graph::UndirectedEdge(n1, n4));
+	ug.addEdge(graph::UndirectedEdge(n2, n5));
+
+	counter = 0;
+	opened.clear();
+	closed.clear();
+	predecessors.clear();
+
+	graph::traverse::Dfs::dfs(ug, n1, [&](const graph::Node &n, const graph::Node &p, int o, int c) {
+		counter++;
+		opened[n] = o;
+		closed[n] = c;
+		predecessors[n] = p;
+	});
+
+	CPPUNIT_ASSERT_EQUAL(5, counter);
+	CPPUNIT_ASSERT_EQUAL(4, closed[n5]);
+
+	CPPUNIT_ASSERT_EQUAL(graph::Node(), predecessors[n1]);
+	CPPUNIT_ASSERT_EQUAL(n1, predecessors[n2]);
+	CPPUNIT_ASSERT_EQUAL(n1, predecessors[n3]);
+	CPPUNIT_ASSERT_EQUAL(n1, predecessors[n4]);
+	CPPUNIT_ASSERT_EQUAL(n2, predecessors[n5]);
+}
+
diff --git a/alib2algo/test-src/graph/traverse/DfsTest.h b/alib2algo/test-src/graph/traverse/DfsTest.h
index c3bf1350e4..280368070a 100644
--- a/alib2algo/test-src/graph/traverse/DfsTest.h
+++ b/alib2algo/test-src/graph/traverse/DfsTest.h
@@ -9,12 +9,14 @@ class GraphDfsTest : public CppUnit::TestFixture
 	CPPUNIT_TEST(testTraverseAll);
 	CPPUNIT_TEST(testEarlyReturn);
 	CPPUNIT_TEST(testDisconnectedGraph);
+	CPPUNIT_TEST(testDfs2);
 	CPPUNIT_TEST_SUITE_END();
 
 public:
 	void testTraverseAll();
 	void testEarlyReturn();
 	void testDisconnectedGraph();
+	void testDfs2();
 };
 
 #endif // DFS_TEST_H_
-- 
GitLab