Move graph routing into route engine
All checks were successful
Build Docker Image / build (push) Successful in 1m56s
All checks were successful
Build Docker Image / build (push) Successful in 1m56s
This commit is contained in:
200
geo_app/route_engine.py
Normal file
200
geo_app/route_engine.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Unified graph routing helpers."""
|
||||
import heapq
|
||||
|
||||
from .arango_client import ensure_graph
|
||||
|
||||
|
||||
def _allowed_next_phase(current_phase, transport_type):
|
||||
"""
|
||||
Phase-based routing: auto → rail* → auto.
|
||||
- end_auto: allow one auto, rail, or offer
|
||||
- end_auto_done: auto used — rail or offer
|
||||
- rail: any number of rail, then one auto or offer
|
||||
- start_auto_done: auto used — only offer
|
||||
"""
|
||||
if current_phase == 'end_auto':
|
||||
if transport_type == 'offer':
|
||||
return 'offer'
|
||||
if transport_type == 'auto':
|
||||
return 'end_auto_done'
|
||||
if transport_type == 'rail':
|
||||
return 'rail'
|
||||
return None
|
||||
if current_phase == 'end_auto_done':
|
||||
if transport_type == 'offer':
|
||||
return 'offer'
|
||||
if transport_type == 'rail':
|
||||
return 'rail'
|
||||
return None
|
||||
if current_phase == 'rail':
|
||||
if transport_type == 'offer':
|
||||
return 'offer'
|
||||
if transport_type == 'rail':
|
||||
return 'rail'
|
||||
if transport_type == 'auto':
|
||||
return 'start_auto_done'
|
||||
return None
|
||||
if current_phase == 'start_auto_done':
|
||||
if transport_type == 'offer':
|
||||
return 'offer'
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _allowed_types_for_phase(phase):
|
||||
if phase == 'end_auto':
|
||||
return ['auto', 'rail', 'offer']
|
||||
if phase == 'end_auto_done':
|
||||
return ['rail', 'offer']
|
||||
if phase == 'rail':
|
||||
return ['rail', 'auto', 'offer']
|
||||
if phase == 'start_auto_done':
|
||||
return ['offer']
|
||||
return ['offer']
|
||||
|
||||
|
||||
def _fetch_neighbors(db, node_key, allowed_types):
|
||||
aql = """
|
||||
FOR edge IN edges
|
||||
FILTER edge.transport_type IN @types
|
||||
FILTER edge._from == @node_id OR edge._to == @node_id
|
||||
LET neighbor_id = edge._from == @node_id ? edge._to : edge._from
|
||||
LET neighbor = DOCUMENT(neighbor_id)
|
||||
FILTER neighbor != null
|
||||
RETURN {
|
||||
neighbor_key: neighbor._key,
|
||||
neighbor_doc: neighbor,
|
||||
from_id: edge._from,
|
||||
to_id: edge._to,
|
||||
transport_type: edge.transport_type,
|
||||
distance_km: edge.distance_km,
|
||||
travel_time_seconds: edge.travel_time_seconds
|
||||
}
|
||||
"""
|
||||
cursor = db.aql.execute(
|
||||
aql,
|
||||
bind_vars={'node_id': f"nodes/{node_key}", 'types': allowed_types},
|
||||
)
|
||||
return list(cursor)
|
||||
|
||||
|
||||
def graph_find_targets(db, start_uuid, target_predicate, route_builder, limit=10, max_expansions=20000):
|
||||
"""Unified graph traversal: auto → rail* → auto, returns routes for target nodes."""
|
||||
ensure_graph()
|
||||
|
||||
nodes_col = db.collection('nodes')
|
||||
start = nodes_col.get(start_uuid)
|
||||
if not start:
|
||||
return []
|
||||
|
||||
queue = []
|
||||
counter = 0
|
||||
heapq.heappush(queue, (0, counter, start_uuid, 'end_auto'))
|
||||
|
||||
visited = {}
|
||||
predecessors = {}
|
||||
node_docs = {start_uuid: start}
|
||||
found = []
|
||||
expansions = 0
|
||||
|
||||
while queue and len(found) < limit and expansions < max_expansions:
|
||||
cost, _, node_key, phase = heapq.heappop(queue)
|
||||
|
||||
if (node_key, phase) in visited and cost > visited[(node_key, phase)]:
|
||||
continue
|
||||
visited[(node_key, phase)] = cost
|
||||
|
||||
node_doc = node_docs.get(node_key)
|
||||
if node_doc and target_predicate(node_doc):
|
||||
path_edges = []
|
||||
state = (node_key, phase)
|
||||
current_key = node_key
|
||||
while state in predecessors:
|
||||
prev_state, edge_info = predecessors[state]
|
||||
prev_key = prev_state[0]
|
||||
path_edges.append((current_key, prev_key, edge_info))
|
||||
state = prev_state
|
||||
current_key = prev_key
|
||||
|
||||
route = route_builder(path_edges, node_docs) if route_builder else None
|
||||
distance_km = route.total_distance_km if route else None
|
||||
|
||||
found.append({
|
||||
'node': node_doc,
|
||||
'route': route,
|
||||
'distance_km': distance_km,
|
||||
'cost': cost,
|
||||
})
|
||||
continue
|
||||
|
||||
neighbors = _fetch_neighbors(db, node_key, _allowed_types_for_phase(phase))
|
||||
expansions += 1
|
||||
|
||||
for neighbor in neighbors:
|
||||
transport_type = neighbor.get('transport_type')
|
||||
next_phase = _allowed_next_phase(phase, transport_type)
|
||||
if next_phase is None:
|
||||
continue
|
||||
|
||||
travel_time = neighbor.get('travel_time_seconds')
|
||||
distance_km = neighbor.get('distance_km')
|
||||
neighbor_key = neighbor.get('neighbor_key')
|
||||
if not neighbor_key:
|
||||
continue
|
||||
|
||||
node_docs[neighbor_key] = neighbor.get('neighbor_doc')
|
||||
step_cost = travel_time if travel_time is not None else (distance_km or 0)
|
||||
new_cost = cost + step_cost
|
||||
|
||||
state_key = (neighbor_key, next_phase)
|
||||
if state_key in visited and new_cost >= visited[state_key]:
|
||||
continue
|
||||
|
||||
counter += 1
|
||||
heapq.heappush(queue, (new_cost, counter, neighbor_key, next_phase))
|
||||
predecessors[state_key] = ((node_key, phase), neighbor)
|
||||
|
||||
return found
|
||||
|
||||
|
||||
def snap_to_nearest_hub(db, lat, lon):
|
||||
aql = """
|
||||
FOR hub IN nodes
|
||||
FILTER hub.node_type == 'logistics' OR hub.node_type == null
|
||||
FILTER hub.product_uuid == null
|
||||
LET types = hub.transport_types != null ? hub.transport_types : []
|
||||
FILTER ('rail' IN types) OR ('sea' IN types)
|
||||
FILTER hub.latitude != null AND hub.longitude != null
|
||||
LET dist = DISTANCE(hub.latitude, hub.longitude, @lat, @lon) / 1000
|
||||
SORT dist ASC
|
||||
LIMIT 1
|
||||
RETURN hub
|
||||
"""
|
||||
cursor = db.aql.execute(aql, bind_vars={'lat': lat, 'lon': lon})
|
||||
hubs = list(cursor)
|
||||
return hubs[0] if hubs else None
|
||||
|
||||
|
||||
def resolve_start_hub(db, source_uuid=None, lat=None, lon=None):
|
||||
nodes_col = db.collection('nodes')
|
||||
|
||||
if source_uuid:
|
||||
node = nodes_col.get(source_uuid)
|
||||
if not node:
|
||||
return None
|
||||
|
||||
if node.get('node_type') in ('logistics', None):
|
||||
types = node.get('transport_types') or []
|
||||
if ('rail' in types) or ('sea' in types):
|
||||
return node
|
||||
|
||||
node_lat = node.get('latitude')
|
||||
node_lon = node.get('longitude')
|
||||
if node_lat is None or node_lon is None:
|
||||
return None
|
||||
return snap_to_nearest_hub(db, node_lat, node_lon)
|
||||
|
||||
if lat is None or lon is None:
|
||||
return None
|
||||
|
||||
return snap_to_nearest_hub(db, lat, lon)
|
||||
@@ -1,11 +1,11 @@
|
||||
"""GraphQL schema for Geo service."""
|
||||
import logging
|
||||
import heapq
|
||||
import math
|
||||
import requests
|
||||
import graphene
|
||||
from django.conf import settings
|
||||
from .arango_client import get_db, ensure_graph
|
||||
from .route_engine import graph_find_targets, resolve_start_hub, snap_to_nearest_hub
|
||||
from .cluster_index import get_clustered_nodes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1044,10 +1044,11 @@ class Query(graphene.ObjectType):
|
||||
logger.info("Hub %s not found", hub_uuid)
|
||||
return []
|
||||
|
||||
matches = _graph_find_targets(
|
||||
matches = graph_find_targets(
|
||||
db,
|
||||
start_uuid=hub_uuid,
|
||||
target_predicate=lambda doc: doc.get('node_type') == 'offer',
|
||||
route_builder=_build_route_from_edges,
|
||||
limit=1000,
|
||||
max_expansions=Query.MAX_EXPANSIONS,
|
||||
)
|
||||
@@ -1164,12 +1165,13 @@ class Query(graphene.ObjectType):
|
||||
logger.info("Hub %s missing coordinates", hub_uuid)
|
||||
return []
|
||||
|
||||
matches = _graph_find_targets(
|
||||
matches = graph_find_targets(
|
||||
db,
|
||||
start_uuid=hub_uuid,
|
||||
target_predicate=lambda doc: doc.get('node_type') == 'offer' and (
|
||||
product_uuid is None or doc.get('product_uuid') == product_uuid
|
||||
),
|
||||
route_builder=_build_route_from_edges,
|
||||
limit=limit,
|
||||
max_expansions=Query.MAX_EXPANSIONS,
|
||||
)
|
||||
@@ -1210,7 +1212,7 @@ class Query(graphene.ObjectType):
|
||||
|
||||
# Graph-based nearest hubs when source_uuid provided
|
||||
if source_uuid:
|
||||
start_hub = _resolve_start_hub(db, source_uuid=source_uuid)
|
||||
start_hub = resolve_start_hub(db, source_uuid=source_uuid)
|
||||
if not start_hub:
|
||||
logger.warning("Source node %s not found for nearest hubs, falling back to coordinate search", source_uuid)
|
||||
else:
|
||||
@@ -1224,10 +1226,11 @@ class Query(graphene.ObjectType):
|
||||
types = doc.get('transport_types') or []
|
||||
return ('rail' in types) or ('sea' in types)
|
||||
|
||||
matches = _graph_find_targets(
|
||||
matches = graph_find_targets(
|
||||
db,
|
||||
start_uuid=start_uuid,
|
||||
target_predicate=is_target_hub,
|
||||
route_builder=_build_route_from_edges,
|
||||
limit=limit,
|
||||
max_expansions=Query.MAX_EXPANSIONS,
|
||||
)
|
||||
@@ -1252,7 +1255,7 @@ class Query(graphene.ObjectType):
|
||||
if product_uuid:
|
||||
return self.resolve_hubs_for_product_graph(info, product_uuid, limit=limit)
|
||||
|
||||
start_hub = _resolve_start_hub(db, lat=lat, lon=lon)
|
||||
start_hub = resolve_start_hub(db, lat=lat, lon=lon)
|
||||
if not start_hub:
|
||||
return []
|
||||
|
||||
@@ -1266,10 +1269,11 @@ class Query(graphene.ObjectType):
|
||||
types = doc.get('transport_types') or []
|
||||
return ('rail' in types) or ('sea' in types)
|
||||
|
||||
matches = _graph_find_targets(
|
||||
matches = graph_find_targets(
|
||||
db,
|
||||
start_uuid=start_uuid,
|
||||
target_predicate=is_target_hub,
|
||||
route_builder=_build_route_from_edges,
|
||||
limit=max(limit - 1, 0),
|
||||
max_expansions=Query.MAX_EXPANSIONS,
|
||||
)
|
||||
@@ -1313,7 +1317,7 @@ class Query(graphene.ObjectType):
|
||||
try:
|
||||
nodes_col = db.collection('nodes')
|
||||
|
||||
start_hub = _resolve_start_hub(db, source_uuid=hub_uuid, lat=lat, lon=lon)
|
||||
start_hub = resolve_start_hub(db, source_uuid=hub_uuid, lat=lat, lon=lon)
|
||||
if not start_hub:
|
||||
logger.info("No hub found near coordinates (%.3f, %.3f)", lat, lon)
|
||||
return []
|
||||
@@ -1529,7 +1533,7 @@ class Query(graphene.ObjectType):
|
||||
logger.info("Offer %s not found", offer_uuid)
|
||||
return None
|
||||
|
||||
nearest_hub = _snap_to_nearest_hub(db, lat, lon)
|
||||
nearest_hub = snap_to_nearest_hub(db, lat, lon)
|
||||
if not nearest_hub:
|
||||
logger.info("No hub found near coordinates (%.3f, %.3f)", lat, lon)
|
||||
return None
|
||||
@@ -1537,10 +1541,11 @@ class Query(graphene.ObjectType):
|
||||
hub_uuid = nearest_hub['_key']
|
||||
logger.info("Found nearest hub %s to coordinates (%.3f, %.3f)", hub_uuid, lat, lon)
|
||||
|
||||
matches = _graph_find_targets(
|
||||
matches = graph_find_targets(
|
||||
db,
|
||||
start_uuid=hub_uuid,
|
||||
target_predicate=lambda doc: doc.get('_key') == offer_uuid,
|
||||
route_builder=_build_route_from_edges,
|
||||
limit=1,
|
||||
max_expansions=Query.MAX_EXPANSIONS,
|
||||
)
|
||||
@@ -1839,199 +1844,3 @@ def _distance_km(lat1, lon1, lat2, lon2):
|
||||
|
||||
|
||||
Query._distance_km = _distance_km
|
||||
|
||||
|
||||
def _graph_allowed_next_phase(current_phase, transport_type):
|
||||
"""
|
||||
Phase-based routing: auto → rail* → auto.
|
||||
- end_auto: allow one auto, rail, or offer
|
||||
- end_auto_done: auto used — rail or offer
|
||||
- rail: any number of rail, then one auto or offer
|
||||
- start_auto_done: auto used — only offer
|
||||
"""
|
||||
if current_phase == 'end_auto':
|
||||
if transport_type == 'offer':
|
||||
return 'offer'
|
||||
if transport_type == 'auto':
|
||||
return 'end_auto_done'
|
||||
if transport_type == 'rail':
|
||||
return 'rail'
|
||||
return None
|
||||
if current_phase == 'end_auto_done':
|
||||
if transport_type == 'offer':
|
||||
return 'offer'
|
||||
if transport_type == 'rail':
|
||||
return 'rail'
|
||||
return None
|
||||
if current_phase == 'rail':
|
||||
if transport_type == 'offer':
|
||||
return 'offer'
|
||||
if transport_type == 'rail':
|
||||
return 'rail'
|
||||
if transport_type == 'auto':
|
||||
return 'start_auto_done'
|
||||
return None
|
||||
if current_phase == 'start_auto_done':
|
||||
if transport_type == 'offer':
|
||||
return 'offer'
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _graph_allowed_types_for_phase(phase):
|
||||
if phase == 'end_auto':
|
||||
return ['auto', 'rail', 'offer']
|
||||
if phase == 'end_auto_done':
|
||||
return ['rail', 'offer']
|
||||
if phase == 'rail':
|
||||
return ['rail', 'auto', 'offer']
|
||||
if phase == 'start_auto_done':
|
||||
return ['offer']
|
||||
return ['offer']
|
||||
|
||||
|
||||
def _graph_fetch_neighbors(db, node_key, allowed_types):
|
||||
aql = """
|
||||
FOR edge IN edges
|
||||
FILTER edge.transport_type IN @types
|
||||
FILTER edge._from == @node_id OR edge._to == @node_id
|
||||
LET neighbor_id = edge._from == @node_id ? edge._to : edge._from
|
||||
LET neighbor = DOCUMENT(neighbor_id)
|
||||
FILTER neighbor != null
|
||||
RETURN {
|
||||
neighbor_key: neighbor._key,
|
||||
neighbor_doc: neighbor,
|
||||
from_id: edge._from,
|
||||
to_id: edge._to,
|
||||
transport_type: edge.transport_type,
|
||||
distance_km: edge.distance_km,
|
||||
travel_time_seconds: edge.travel_time_seconds
|
||||
}
|
||||
"""
|
||||
cursor = db.aql.execute(
|
||||
aql,
|
||||
bind_vars={'node_id': f"nodes/{node_key}", 'types': allowed_types},
|
||||
)
|
||||
return list(cursor)
|
||||
|
||||
|
||||
def _graph_find_targets(db, start_uuid, target_predicate, limit=10, max_expansions=20000):
|
||||
"""Unified graph traversal: auto → rail* → auto, returns routes for target nodes."""
|
||||
ensure_graph()
|
||||
|
||||
nodes_col = db.collection('nodes')
|
||||
start = nodes_col.get(start_uuid)
|
||||
if not start:
|
||||
return []
|
||||
|
||||
queue = []
|
||||
counter = 0
|
||||
heapq.heappush(queue, (0, counter, start_uuid, 'end_auto'))
|
||||
|
||||
visited = {}
|
||||
predecessors = {}
|
||||
node_docs = {start_uuid: start}
|
||||
found = []
|
||||
expansions = 0
|
||||
|
||||
while queue and len(found) < limit and expansions < max_expansions:
|
||||
cost, _, node_key, phase = heapq.heappop(queue)
|
||||
|
||||
if (node_key, phase) in visited and cost > visited[(node_key, phase)]:
|
||||
continue
|
||||
visited[(node_key, phase)] = cost
|
||||
|
||||
node_doc = node_docs.get(node_key)
|
||||
if node_doc and target_predicate(node_doc):
|
||||
path_edges = []
|
||||
state = (node_key, phase)
|
||||
current_key = node_key
|
||||
while state in predecessors:
|
||||
prev_state, edge_info = predecessors[state]
|
||||
prev_key = prev_state[0]
|
||||
path_edges.append((current_key, prev_key, edge_info))
|
||||
state = prev_state
|
||||
current_key = prev_key
|
||||
|
||||
route = _build_route_from_edges(path_edges, node_docs)
|
||||
distance_km = route.total_distance_km if route else None
|
||||
|
||||
found.append({
|
||||
'node': node_doc,
|
||||
'route': route,
|
||||
'distance_km': distance_km,
|
||||
'cost': cost,
|
||||
})
|
||||
continue
|
||||
|
||||
neighbors = _graph_fetch_neighbors(db, node_key, _graph_allowed_types_for_phase(phase))
|
||||
expansions += 1
|
||||
|
||||
for neighbor in neighbors:
|
||||
transport_type = neighbor.get('transport_type')
|
||||
next_phase = _graph_allowed_next_phase(phase, transport_type)
|
||||
if next_phase is None:
|
||||
continue
|
||||
|
||||
travel_time = neighbor.get('travel_time_seconds')
|
||||
distance_km = neighbor.get('distance_km')
|
||||
neighbor_key = neighbor.get('neighbor_key')
|
||||
if not neighbor_key:
|
||||
continue
|
||||
|
||||
node_docs[neighbor_key] = neighbor.get('neighbor_doc')
|
||||
step_cost = travel_time if travel_time is not None else (distance_km or 0)
|
||||
new_cost = cost + step_cost
|
||||
|
||||
state_key = (neighbor_key, next_phase)
|
||||
if state_key in visited and new_cost >= visited[state_key]:
|
||||
continue
|
||||
|
||||
counter += 1
|
||||
heapq.heappush(queue, (new_cost, counter, neighbor_key, next_phase))
|
||||
predecessors[state_key] = ((node_key, phase), neighbor)
|
||||
|
||||
return found
|
||||
|
||||
|
||||
def _snap_to_nearest_hub(db, lat, lon):
|
||||
aql = """
|
||||
FOR hub IN nodes
|
||||
FILTER hub.node_type == 'logistics' OR hub.node_type == null
|
||||
FILTER hub.product_uuid == null
|
||||
LET types = hub.transport_types != null ? hub.transport_types : []
|
||||
FILTER ('rail' IN types) OR ('sea' IN types)
|
||||
FILTER hub.latitude != null AND hub.longitude != null
|
||||
LET dist = DISTANCE(hub.latitude, hub.longitude, @lat, @lon) / 1000
|
||||
SORT dist ASC
|
||||
LIMIT 1
|
||||
RETURN hub
|
||||
"""
|
||||
cursor = db.aql.execute(aql, bind_vars={'lat': lat, 'lon': lon})
|
||||
hubs = list(cursor)
|
||||
return hubs[0] if hubs else None
|
||||
|
||||
|
||||
def _resolve_start_hub(db, source_uuid=None, lat=None, lon=None):
|
||||
nodes_col = db.collection('nodes')
|
||||
|
||||
if source_uuid:
|
||||
node = nodes_col.get(source_uuid)
|
||||
if not node:
|
||||
return None
|
||||
|
||||
if node.get('node_type') in ('logistics', None):
|
||||
types = node.get('transport_types') or []
|
||||
if ('rail' in types) or ('sea' in types):
|
||||
return node
|
||||
|
||||
node_lat = node.get('latitude')
|
||||
node_lon = node.get('longitude')
|
||||
if node_lat is None or node_lon is None:
|
||||
return None
|
||||
return _snap_to_nearest_hub(db, node_lat, node_lon)
|
||||
|
||||
if lat is None or lon is None:
|
||||
return None
|
||||
|
||||
return _snap_to_nearest_hub(db, lat, lon)
|
||||
|
||||
Reference in New Issue
Block a user