Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions graphs/johnson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import heapq
from collections.abc import Hashable

Node = Hashable
edge = tuple[Node, Node, float]
adjacency = dict[Node, list[tuple[Node, float]]]
Comment on lines +4 to +6


def _collect_nodes_and_edges(graph: adjacency) -> tuple[list[Node], list[edge]]:
nodes = set()
edges: list[edge] = []
for u, neighbors in graph.items():
nodes.add(u)
for v, w in neighbors:
nodes.add(v)
edges.append((u, v, w))
return list(nodes), edges


def _bellman_ford(nodes: list[Node], edges: list[edge]) -> dict[Node, float]:
"""
Bellman-Ford relaxation to compute potentials h[v] for all vertices.
Raises ValueError if a negative weight cycle exists.
"""
dist: dict[Node, float] = dict.fromkeys(nodes, 0.0)
n = len(nodes)

for _ in range(n - 1):
updated = False
for u, v, w in edges:
if dist[u] + w < dist[v]:
dist[v] = dist[u] + w
updated = True
if not updated:
break
else:
for u, v, w in edges:
if dist[u] + w < dist[v]:
raise ValueError("Negative weight cycle detected")
return dist


def _dijkstra(
start: Node,
nodes: list[Node],
graph: adjacency,
potentials: dict[Node, float],
) -> dict[Node, float]:
"""
Dijkstra over reweighted graph, using potentials h to make weights non-negative.
Returns distances from start in the reweighted space.
"""
inf = float("inf")
dist: dict[Node, float] = dict.fromkeys(nodes, inf)
dist[start] = 0.0
heap: list[tuple[float, Node]] = [(0.0, start)]

Comment on lines +56 to +57
while heap:
d_u, u = heapq.heappop(heap)
if d_u > dist[u]:
continue
for v, w in graph.get(u, []):
w_prime = w + potentials[u] - potentials[v]
if w_prime < 0:
raise ValueError(
"Negative edge weight after reweighting: numeric error"
)
Comment on lines +63 to +67
new_dist = d_u + w_prime
if new_dist < dist[v]:
dist[v] = new_dist
heapq.heappush(heap, (new_dist, v))
return dist


def johnson(graph: adjacency) -> dict[Node, dict[Node, float]]:
"""
Compute all-pairs shortest paths using Johnson's algorithm.

Comment thread
cclauss marked this conversation as resolved.
Reference:
https://en.wikipedia.org/wiki/Johnson%27s_algorithm

Args:
graph: adjacency list {u: [(v, weight), ...], ...}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot is this still a useful change?


Returns:
dict of dicts: dist[u][v] = shortest distance from u to v

Raises:
ValueError: if a negative weight cycle is detected

Comment on lines +88 to +90
Example:
>>> g = {
... 0: [(1, 3), (2, 8), (4, -4)],
... 1: [(3, 1), (4, 7)],
... 2: [(1, 4)],
Comment on lines +91 to +95
... 3: [(0, 2), (2, -5)],
... 4: [(3, 6)],
... }
>>> round(johnson(g)[0][3], 2)
2.0
"""
nodes, edges = _collect_nodes_and_edges(graph)
potentials = _bellman_ford(nodes, edges)

all_pairs: dict[Node, dict[Node, float]] = {}
inf = float("inf")
for s in nodes:
dist_reweighted = _dijkstra(s, nodes, graph, potentials)
dists_orig: dict[Node, float] = {}
for v in nodes:
d_prime = dist_reweighted[v]
if d_prime < inf:
dists_orig[v] = d_prime - potentials[s] + potentials[v]
else:
dists_orig[v] = inf
all_pairs[s] = dists_orig

return all_pairs
24 changes: 24 additions & 0 deletions graphs/tests/test_johnson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import math

import pytest

from graphs.johnson import johnson


def test_johnson_basic():
g = {
0: [(1, 3), (2, 8), (4, -4)],
1: [(3, 1), (4, 7)],
2: [(1, 4)],
3: [(0, 2), (2, -5)],
4: [(3, 6)],
}
dist = johnson(g)
assert math.isclose(dist[0][3], 2.0, abs_tol=1e-9)
assert math.isclose(dist[3][2], -5.0, abs_tol=1e-9)


Comment on lines +17 to +20
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot is this still a useful change?

def test_johnson_negative_cycle():
g2 = {0: [(1, 1)], 1: [(0, -3)]}
with pytest.raises(ValueError):
johnson(g2)