From 0330203a5812cf7945b01b1c5fb47fb9a5c99131 Mon Sep 17 00:00:00 2001 From: Ruslan Bakiev <572431+veikab@users.noreply.github.com> Date: Wed, 14 Jan 2026 10:24:40 +0700 Subject: [PATCH] Replace pysupercluster with h3 for clustering --- geo_app/cluster_index.py | 213 +++++++++++++++------------------------ pyproject.toml | 3 +- 2 files changed, 81 insertions(+), 135 deletions(-) diff --git a/geo_app/cluster_index.py b/geo_app/cluster_index.py index 231ac3d..6fafea4 100644 --- a/geo_app/cluster_index.py +++ b/geo_app/cluster_index.py @@ -1,104 +1,33 @@ """ -Cached SuperCluster index for server-side map clustering. +Server-side map clustering using Uber H3 hexagonal grid. -Uses pysupercluster for fast geospatial point clustering. -Index is lazily initialized on first request and cached in memory. +Maps zoom levels to h3 resolutions and groups nodes by cell. """ import logging import threading -import numpy as np +import h3 logger = logging.getLogger(__name__) -# Global cache for cluster indices -_cluster_cache = {} +# Global cache for nodes +_nodes_cache = {} _cache_lock = threading.Lock() - -def _build_index(nodes, transport_type=None): - """ - Build SuperCluster index from node list. - - Args: - nodes: List of node dicts with latitude, longitude, _key, name - transport_type: Optional filter for transport type - - Returns: - Tuple of (SuperCluster index, node_data dict keyed by index) - """ - try: - import pysupercluster - except ImportError: - logger.error("pysupercluster not installed") - return None, {} - - # Filter nodes with valid coordinates - valid_nodes = [] - for node in nodes: - lat = node.get('latitude') - lon = node.get('longitude') - if lat is not None and lon is not None: - # Filter by transport type if specified - if transport_type: - types = node.get('transport_types') or [] - if transport_type not in types: - continue - valid_nodes.append(node) - - if not valid_nodes: - logger.warning("No valid nodes for clustering") - return None, {} - - # Build numpy array of coordinates (lon, lat) - coords = np.array([ - (node['longitude'], node['latitude']) - for node in valid_nodes - ]) - - # Build node data lookup by index - node_data = { - i: { - 'uuid': node.get('_key'), - 'name': node.get('name'), - 'latitude': node.get('latitude'), - 'longitude': node.get('longitude'), - } - for i, node in enumerate(valid_nodes) - } - - # Create SuperCluster index - # min_zoom=0, max_zoom=16 covers typical map zoom range - # radius=60 pixels for clustering - index = pysupercluster.SuperCluster( - coords, - min_zoom=0, - max_zoom=16, - radius=60, - extent=512, - ) - - logger.info("Built cluster index with %d points", len(valid_nodes)) - return index, node_data +# Map zoom level to h3 resolution +# Higher zoom = higher resolution = smaller cells +ZOOM_TO_RES = { + 0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, + 6: 3, 7: 3, 8: 4, 9: 4, 10: 5, 11: 5, + 12: 6, 13: 7, 14: 8, 15: 9, 16: 10 +} -def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None): - """ - Get clustered nodes for given bounding box and zoom level. - - Args: - db: ArangoDB connection - west, south, east, north: Bounding box coordinates - zoom: Map zoom level (integer) - transport_type: Optional filter - - Returns: - List of cluster/point dicts with id, latitude, longitude, count, expansion_zoom, name - """ +def _fetch_nodes(db, transport_type=None): + """Fetch nodes from database with caching.""" cache_key = f"nodes:{transport_type or 'all'}" with _cache_lock: - if cache_key not in _cluster_cache: - # Load all nodes from database + if cache_key not in _nodes_cache: aql = """ FOR node IN nodes FILTER node.node_type == 'logistics' OR node.node_type == null @@ -108,68 +37,86 @@ def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None) cursor = db.aql.execute(aql) all_nodes = list(cursor) - # Build index - index, node_data = _build_index(all_nodes, transport_type) - _cluster_cache[cache_key] = (index, node_data, all_nodes) + # Filter by transport type if specified + if transport_type: + all_nodes = [ + n for n in all_nodes + if transport_type in (n.get('transport_types') or []) + ] - index, node_data, all_nodes = _cluster_cache[cache_key] + _nodes_cache[cache_key] = all_nodes + logger.info("Cached %d nodes for %s", len(all_nodes), cache_key) - if index is None: + return _nodes_cache[cache_key] + + +def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None): + """ + Get clustered nodes for given bounding box and zoom level. + + Uses H3 hexagonal grid to group nearby nodes. + """ + resolution = ZOOM_TO_RES.get(int(zoom), 5) + nodes = _fetch_nodes(db, transport_type) + + if not nodes: return [] - # Get clusters for bounding box - # pysupercluster uses top_left (lon, lat) and bottom_right (lon, lat) - try: - clusters = index.getClusters( - top_left=(west, north), - bottom_right=(east, south), - zoom=int(zoom), - ) - except Exception as e: - logger.error("getClusters failed: %s", e) - return [] + # Group nodes by h3 cell + cells = {} + for node in nodes: + lat = node.get('latitude') + lng = node.get('longitude') + # Skip nodes outside bounding box (rough filter) + if lat < south or lat > north or lng < west or lng > east: + continue + + cell = h3.latlng_to_cell(lat, lng, resolution) + if cell not in cells: + cells[cell] = [] + cells[cell].append(node) + + # Build results results = [] - for cluster in clusters: - cluster_id = cluster.get('id') - count = cluster.get('count', 1) - lat = cluster.get('latitude') - lon = cluster.get('longitude') - expansion_zoom = cluster.get('expansion_zoom') + for cell, nodes_in_cell in cells.items(): + count = len(nodes_in_cell) - # For single points (count=1), get the actual node data - name = None - uuid = None - if count == 1 and cluster_id is not None and cluster_id in node_data: - node_info = node_data[cluster_id] - name = node_info.get('name') - uuid = node_info.get('uuid') + if count == 1: + # Single point — return actual node data + node = nodes_in_cell[0] + results.append({ + 'id': node.get('_key'), + 'latitude': node.get('latitude'), + 'longitude': node.get('longitude'), + 'count': 1, + 'expansion_zoom': None, + 'name': node.get('name'), + }) + else: + # Cluster — return cell centroid + lat, lng = h3.cell_to_latlng(cell) + results.append({ + 'id': f"cluster-{cell}", + 'latitude': lat, + 'longitude': lng, + 'count': count, + 'expansion_zoom': min(zoom + 2, 16), + 'name': None, + }) - results.append({ - 'id': uuid or f"cluster-{cluster_id}", - 'latitude': lat, - 'longitude': lon, - 'count': count, - 'expansion_zoom': expansion_zoom, - 'name': name, - }) - - logger.info("Returning %d clusters/points for zoom=%d", len(results), zoom) + logger.info("Returning %d clusters/points for zoom=%d res=%d", len(results), zoom, resolution) return results def invalidate_cache(transport_type=None): - """ - Invalidate cluster cache. - - Call this after nodes are updated in the database. - """ + """Invalidate node cache after data changes.""" with _cache_lock: if transport_type: cache_key = f"nodes:{transport_type}" - if cache_key in _cluster_cache: - del _cluster_cache[cache_key] + if cache_key in _nodes_cache: + del _nodes_cache[cache_key] else: - _cluster_cache.clear() + _nodes_cache.clear() logger.info("Cluster cache invalidated") diff --git a/pyproject.toml b/pyproject.toml index 4763c2c..bf20835 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,8 +16,7 @@ dependencies = [ "gunicorn (>=23.0.0,<24.0.0)", "whitenoise (>=6.7.0,<7.0.0)", "sentry-sdk (>=2.47.0,<3.0.0)", - "pysupercluster (>=0.7.7,<1.0.0)", - "numpy (>=1.26.0,<3.0.0)" + "h3 (>=4.0.0,<5.0.0)" ] [tool.poetry]