"""
Cached SuperCluster index for server-side map clustering.

Uses pysupercluster for fast geospatial point clustering.
Index is lazily initialized on first request and cached in memory.
"""
import logging
import threading
import numpy as np

logger = logging.getLogger(__name__)

# Global cache for cluster indices; every access goes through _cache_lock.
_cluster_cache = {}
_cache_lock = threading.Lock()


def _build_index(nodes, transport_type=None):
    """
    Build a SuperCluster index from a node list.

    Args:
        nodes: List of node dicts with latitude, longitude, _key, name
        transport_type: Optional filter for transport type

    Returns:
        Tuple of (SuperCluster index, metadata dict keyed by point index).
        Returns (None, {}) when pysupercluster is not importable or no
        node has usable coordinates.
    """
    try:
        import pysupercluster
    except ImportError:
        logger.error("pysupercluster not installed")
        return None, {}

    def _clusterable(node):
        # A node participates when both coordinates are present and it
        # matches the transport-type filter (when one was requested).
        if node.get('latitude') is None or node.get('longitude') is None:
            return False
        if transport_type and transport_type not in (node.get('transport_types') or []):
            return False
        return True

    points = [node for node in nodes if _clusterable(node)]
    if not points:
        logger.warning("No valid nodes for clustering")
        return None, {}

    # SuperCluster expects (lon, lat) pairs.
    xy = np.array([(p['longitude'], p['latitude']) for p in points])

    # Per-point metadata keyed by the point's index in ``xy`` — the same
    # index pysupercluster reports as ``id`` for unclustered points.
    lookup = {
        i: {
            'uuid': p.get('_key'),
            'name': p.get('name'),
            'latitude': p.get('latitude'),
            'longitude': p.get('longitude'),
        }
        for i, p in enumerate(points)
    }

    # min_zoom/max_zoom span the typical web-map zoom range; radius is in
    # screen pixels at the given tile extent.
    index = pysupercluster.SuperCluster(
        xy,
        min_zoom=0,
        max_zoom=16,
        radius=60,
        extent=512,
    )

    logger.info("Built cluster index with %d points", len(points))
    return index, lookup
def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None):
    """
    Get clustered nodes for the given bounding box and zoom level.

    Args:
        db: ArangoDB connection
        west, south, east, north: Bounding box coordinates
        zoom: Map zoom level (integer)
        transport_type: Optional filter

    Returns:
        List of cluster/point dicts with id, latitude, longitude, count,
        expansion_zoom, name. Empty list when no index could be built or
        the cluster lookup fails.
    """
    cache_key = f"nodes:{transport_type or 'all'}"

    # The DB load and index build run while holding the lock so that
    # concurrent cold-cache requests don't each rebuild the same index.
    with _cache_lock:
        if cache_key not in _cluster_cache:
            aql = """
            FOR node IN nodes
            FILTER node.node_type == 'logistics' OR node.node_type == null
            FILTER node.latitude != null AND node.longitude != null
            RETURN node
            """
            cursor = db.aql.execute(aql)
            all_nodes = list(cursor)

            # Only the index and the per-point metadata are retained.
            # Fix: the raw node list was previously cached alongside them,
            # which kept every node dict alive for the cache's lifetime
            # without ever being read again.
            index, node_data = _build_index(all_nodes, transport_type)
            _cluster_cache[cache_key] = (index, node_data)

        index, node_data = _cluster_cache[cache_key]

    if index is None:
        return []

    # pysupercluster takes the viewport as top_left=(lon, lat) and
    # bottom_right=(lon, lat).
    try:
        clusters = index.getClusters(
            top_left=(west, north),
            bottom_right=(east, south),
            zoom=int(zoom),
        )
    except Exception as e:
        # Best-effort: a bad viewport/zoom degrades to "no data", not a 500.
        logger.error("getClusters failed: %s", e)
        return []

    results = []
    for cluster in clusters:
        cluster_id = cluster.get('id')
        count = cluster.get('count', 1)
        expansion_zoom = cluster.get('expansion_zoom')

        # For single points (count == 1), ``id`` is the index into the
        # original point array, so the node's name/uuid can be recovered.
        name = None
        uuid = None
        if count == 1 and cluster_id is not None and cluster_id in node_data:
            node_info = node_data[cluster_id]
            name = node_info.get('name')
            uuid = node_info.get('uuid')

        results.append({
            'id': uuid or f"cluster-{cluster_id}",
            'latitude': cluster.get('latitude'),
            'longitude': cluster.get('longitude'),
            'count': count,
            'expansion_zoom': expansion_zoom,
            'name': name,
        })

    logger.info("Returning %d clusters/points for zoom=%d", len(results), zoom)
    return results
def invalidate_cache(transport_type=None):
    """
    Invalidate cluster cache.

    Call this after nodes are updated in the database. When a
    transport_type is given only that type's entry is dropped;
    otherwise the whole cache is cleared.
    """
    with _cache_lock:
        if transport_type:
            # pop() tolerates a missing key, so no membership check needed.
            _cluster_cache.pop(f"nodes:{transport_type}", None)
        else:
            _cluster_cache.clear()

    logger.info("Cluster cache invalidated")


class ClusterPointType(graphene.ObjectType):
    """Cluster or individual point for map display."""
    id = graphene.String(description="UUID for points, 'cluster-N' for clusters")
    latitude = graphene.Float()
    longitude = graphene.Float()
    count = graphene.Int(description="1 for single point, >1 for cluster")
    expansion_zoom = graphene.Int(description="Zoom level to expand cluster")
    name = graphene.String(description="Node name (only for single points)")
clustered_nodes = graphene.List(
    ClusterPointType,
    west=graphene.Float(required=True, description="Bounding box west longitude"),
    south=graphene.Float(required=True, description="Bounding box south latitude"),
    east=graphene.Float(required=True, description="Bounding box east longitude"),
    north=graphene.Float(required=True, description="Bounding box north latitude"),
    zoom=graphene.Int(required=True, description="Map zoom level (0-16)"),
    transport_type=graphene.String(description="Filter by transport type"),
    description="Get clustered nodes for map display (server-side clustering)",
)


def resolve_clustered_nodes(self, info, west, south, east, north, zoom, transport_type=None):
    """Get clustered nodes for map display using server-side SuperCluster."""
    # Delegate to the cached SuperCluster index, then wrap each plain dict
    # in the graphene type for the response.
    points = get_clustered_nodes(get_db(), west, south, east, north, zoom, transport_type)
    return [ClusterPointType(**point) for point in points]