Add server-side clustering with pysupercluster
Some checks failed
Build Docker Image / build (push) Failing after 2m14s

This commit is contained in:
Ruslan Bakiev
2026-01-14 10:12:39 +07:00
parent e15976382e
commit 7efa753092
3 changed files with 206 additions and 1 deletions

175
geo_app/cluster_index.py Normal file
View File

@@ -0,0 +1,175 @@
"""
Cached SuperCluster index for server-side map clustering.
Uses pysupercluster for fast geospatial point clustering.
Index is lazily initialized on first request and cached in memory.
"""
import logging
import threading
import numpy as np
logger = logging.getLogger(__name__)
# Global cache for cluster indices
_cluster_cache = {}
_cache_lock = threading.Lock()
def _build_index(nodes, transport_type=None):
    """
    Construct a SuperCluster index from a list of node dicts.

    Args:
        nodes: List of node dicts with latitude, longitude, _key, name
        transport_type: Optional filter for transport type

    Returns:
        Tuple of (SuperCluster index, node_data dict keyed by point index).
        Returns (None, {}) when pysupercluster is missing or no node has
        usable coordinates.
    """
    try:
        import pysupercluster
    except ImportError:
        logger.error("pysupercluster not installed")
        return None, {}

    def _accepted(node):
        # A node is clusterable only when both coordinates are present;
        # when a transport filter is given it must also advertise that type.
        if node.get('latitude') is None or node.get('longitude') is None:
            return False
        if transport_type and transport_type not in (node.get('transport_types') or []):
            return False
        return True

    usable = [node for node in nodes if _accepted(node)]
    if not usable:
        logger.warning("No valid nodes for clustering")
        return None, {}

    # SuperCluster expects points as (lon, lat) pairs.
    points = np.array([(node['longitude'], node['latitude']) for node in usable])

    # Per-point metadata, addressable by the point index SuperCluster
    # reports back for leaf entries.
    lookup = {}
    for position, node in enumerate(usable):
        lookup[position] = {
            'uuid': node.get('_key'),
            'name': node.get('name'),
            'latitude': node.get('latitude'),
            'longitude': node.get('longitude'),
        }

    # min_zoom=0 .. max_zoom=16 covers the typical web-map zoom range;
    # radius is the clustering radius in pixels at the given tile extent.
    index = pysupercluster.SuperCluster(
        points,
        min_zoom=0,
        max_zoom=16,
        radius=60,
        extent=512,
    )
    logger.info("Built cluster index with %d points", len(usable))
    return index, lookup
def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None):
    """
    Get clustered nodes for the given bounding box and zoom level.

    The node set is loaded from ArangoDB once per transport_type; the built
    SuperCluster index is cached in memory. Call invalidate_cache() after
    nodes change in the database.

    Args:
        db: ArangoDB connection
        west, south, east, north: Bounding box coordinates
        zoom: Map zoom level (coerced to int)
        transport_type: Optional filter

    Returns:
        List of cluster/point dicts with id, latitude, longitude, count,
        expansion_zoom, name. Empty list when no index could be built or
        the cluster query fails.
    """
    # Coerce once so getClusters and the %d log statement agree.
    zoom = int(zoom)
    cache_key = f"nodes:{transport_type or 'all'}"

    # The lock is deliberately held across the DB load and index build so
    # concurrent first requests do not build the same index twice.
    with _cache_lock:
        if cache_key not in _cluster_cache:
            aql = """
            FOR node IN nodes
            FILTER node.node_type == 'logistics' OR node.node_type == null
            FILTER node.latitude != null AND node.longitude != null
            RETURN node
            """
            all_nodes = list(db.aql.execute(aql))
            # Cache only (index, node_data): the raw node list is never read
            # again, and keeping it alive would just pin memory.
            _cluster_cache[cache_key] = _build_index(all_nodes, transport_type)
        index, node_data = _cluster_cache[cache_key]

    if index is None:
        return []

    # pysupercluster takes the viewport as top_left/bottom_right (lon, lat).
    try:
        clusters = index.getClusters(
            top_left=(west, north),
            bottom_right=(east, south),
            zoom=zoom,
        )
    except Exception as e:
        logger.error("getClusters failed: %s", e)
        return []

    results = []
    for cluster in clusters:
        cluster_id = cluster.get('id')
        count = cluster.get('count', 1)

        # Leaf entries (count == 1) carry the original point index as their
        # id, which lets us recover the node's uuid and display name.
        name = None
        uuid = None
        if count == 1 and cluster_id is not None and cluster_id in node_data:
            node_info = node_data[cluster_id]
            name = node_info.get('name')
            uuid = node_info.get('uuid')

        results.append({
            'id': uuid or f"cluster-{cluster_id}",
            'latitude': cluster.get('latitude'),
            'longitude': cluster.get('longitude'),
            'count': count,
            'expansion_zoom': cluster.get('expansion_zoom'),
            'name': name,
        })
    logger.info("Returning %d clusters/points for zoom=%d", len(results), zoom)
    return results
def invalidate_cache(transport_type=None):
    """
    Invalidate the cluster cache.

    Call this after nodes are updated in the database. With a transport_type
    only that index is dropped; without one, every cached index is cleared.
    """
    with _cache_lock:
        if not transport_type:
            _cluster_cache.clear()
        else:
            # pop() with a default is a no-op when the key is absent.
            _cluster_cache.pop(f"nodes:{transport_type}", None)
    logger.info("Cluster cache invalidated")

View File

@@ -6,6 +6,7 @@ import requests
import graphene
from django.conf import settings
from .arango_client import get_db, ensure_graph
from .cluster_index import get_clustered_nodes, invalidate_cache
logger = logging.getLogger(__name__)
@@ -80,6 +81,16 @@ class ProductRouteOptionType(graphene.ObjectType):
routes = graphene.List(RoutePathType)
class ClusterPointType(graphene.ObjectType):
    """Cluster or individual point for map display."""
    # Field shape mirrors the dicts produced by cluster_index.get_clustered_nodes.
    id = graphene.String(description="UUID for points, 'cluster-N' for clusters")
    latitude = graphene.Float()
    longitude = graphene.Float()
    count = graphene.Int(description="1 for single point, >1 for cluster")
    expansion_zoom = graphene.Int(description="Zoom level to expand cluster")
    name = graphene.String(description="Node name (only for single points)")
class Query(graphene.ObjectType):
"""Root query."""
MAX_EXPANSIONS = 20000
@@ -161,6 +172,17 @@ class Query(graphene.ObjectType):
description="Find routes from product offer nodes to destination",
)
clustered_nodes = graphene.List(
ClusterPointType,
west=graphene.Float(required=True, description="Bounding box west longitude"),
south=graphene.Float(required=True, description="Bounding box south latitude"),
east=graphene.Float(required=True, description="Bounding box east longitude"),
north=graphene.Float(required=True, description="Bounding box north latitude"),
zoom=graphene.Int(required=True, description="Map zoom level (0-16)"),
transport_type=graphene.String(description="Filter by transport type"),
description="Get clustered nodes for map display (server-side clustering)",
)
@staticmethod
def _build_routes(db, from_uuid, to_uuid, limit):
"""Shared helper to compute K shortest routes between two nodes."""
@@ -740,6 +762,12 @@ class Query(graphene.ObjectType):
return found_routes
def resolve_clustered_nodes(self, info, west, south, east, north, zoom, transport_type=None):
    """Get clustered nodes for map display using server-side SuperCluster."""
    cluster_dicts = get_clustered_nodes(
        get_db(), west, south, east, north, zoom, transport_type
    )
    # Wrap each raw dict in the GraphQL object type the schema declares.
    return [ClusterPointType(**payload) for payload in cluster_dicts]
schema = graphene.Schema(query=Query)