Replace pysupercluster with h3 for clustering
All checks were successful
Build Docker Image / build (push) Successful in 1m38s

This commit is contained in:
Ruslan Bakiev
2026-01-14 10:24:40 +07:00
parent 7efa753092
commit 0330203a58
2 changed files with 81 additions and 135 deletions

View File

@@ -1,104 +1,33 @@
"""
Server-side map clustering using Uber H3 hexagonal grid.

Maps zoom levels to H3 resolutions and groups nodes by cell.  Node lists
are lazily fetched on first request and cached in memory.  (The legacy
pysupercluster index builder is still present below for reference.)
"""
import logging
import threading

import numpy as np  # used only by the legacy pysupercluster index builder
import h3

logger = logging.getLogger(__name__)

# Legacy cache: cache key -> (SuperCluster index, node_data, nodes).
_cluster_cache = {}
# Cache of filtered node lists, keyed by "nodes:<transport_type or 'all'>".
_nodes_cache = {}
# Guards lazy initialization of both caches.
_cache_lock = threading.Lock()
def _build_index(nodes, transport_type=None):
    """
    Build SuperCluster index from node list.

    Args:
        nodes: List of node dicts with latitude, longitude, _key, name
        transport_type: Optional filter for transport type

    Returns:
        Tuple of (SuperCluster index, node_data dict keyed by index);
        (None, {}) when the library is missing or no node qualifies.
    """
    try:
        import pysupercluster
    except ImportError:
        logger.error("pysupercluster not installed")
        return None, {}

    # Keep only nodes with usable coordinates, honoring the optional
    # transport-type filter.
    valid_nodes = []
    for node in nodes:
        if node.get('latitude') is None or node.get('longitude') is None:
            continue
        if transport_type and transport_type not in (node.get('transport_types') or []):
            continue
        valid_nodes.append(node)

    if not valid_nodes:
        logger.warning("No valid nodes for clustering")
        return None, {}

    # Coordinate matrix in (lon, lat) order, as pysupercluster expects.
    coords = np.array([(n['longitude'], n['latitude']) for n in valid_nodes])

    # Per-point metadata, addressable by the point's index in `coords`.
    node_data = {}
    for idx, node in enumerate(valid_nodes):
        node_data[idx] = {
            'uuid': node.get('_key'),
            'name': node.get('name'),
            'latitude': node.get('latitude'),
            'longitude': node.get('longitude'),
        }

    # min_zoom=0..max_zoom=16 covers the typical map zoom range;
    # radius is the clustering radius in pixels at the given tile extent.
    index = pysupercluster.SuperCluster(
        coords,
        min_zoom=0,
        max_zoom=16,
        radius=60,
        extent=512,
    )
    logger.info("Built cluster index with %d points", len(valid_nodes))
    return index, node_data
# Map zoom level -> H3 resolution.
# Higher zoom means a finer grid (smaller hexagonal cells).
ZOOM_TO_RES = {
    0: 0,
    1: 0,
    2: 1,
    3: 1,
    4: 2,
    5: 2,
    6: 3,
    7: 3,
    8: 4,
    9: 4,
    10: 5,
    11: 5,
    12: 6,
    13: 7,
    14: 8,
    15: 9,
    16: 10,
}
# NOTE(review): this span is a unified-diff paste interleaving the REMOVED
# get_clustered_nodes docstring / pysupercluster lookup with the ADDED
# _fetch_nodes helper; indentation was stripped by the paste.  Lines are
# kept byte-identical below, with review annotations only.
# --- removed: old get_clustered_nodes signature and docstring ---
def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None):
"""
Get clustered nodes for given bounding box and zoom level.
Args:
db: ArangoDB connection
west, south, east, north: Bounding box coordinates
zoom: Map zoom level (integer)
transport_type: Optional filter
Returns:
List of cluster/point dicts with id, latitude, longitude, count, expansion_zoom, name
"""
# --- added: cached node fetch used by the new h3 clustering path ---
def _fetch_nodes(db, transport_type=None):
"""Fetch nodes from database with caching."""
cache_key = f"nodes:{transport_type or 'all'}"
with _cache_lock:
# removed: old guard keyed on the pysupercluster index cache
if cache_key not in _cluster_cache:
# Load all nodes from database
# added: new guard keyed on the plain node-list cache
if cache_key not in _nodes_cache:
# NOTE(review): the AQL text below is TRUNCATED at the diff hunk
# boundary (the '@@' line) — the full query and its closing quotes
# are not visible in this chunk; do not treat it as complete.
aql = """
FOR node IN nodes
FILTER node.node_type == 'logistics' OR node.node_type == null
@@ -108,68 +37,86 @@ def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None)
cursor = db.aql.execute(aql)
all_nodes = list(cursor)
# Build index
index, node_data = _build_index(all_nodes, transport_type)
_cluster_cache[cache_key] = (index, node_data, all_nodes)
# Filter by transport type if specified
if transport_type:
all_nodes = [
n for n in all_nodes
if transport_type in (n.get('transport_types') or [])
]
index, node_data, all_nodes = _cluster_cache[cache_key]
_nodes_cache[cache_key] = all_nodes
logger.info("Cached %d nodes for %s", len(all_nodes), cache_key)
if index is None:
return _nodes_cache[cache_key]
def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None):
    """
    Get clustered nodes for given bounding box and zoom level.

    Uses the H3 hexagonal grid to group nearby nodes: each node falls into
    the H3 cell at the resolution mapped from ``zoom``; cells holding a
    single node are returned as plain points, larger cells as clusters.

    Args:
        db: ArangoDB connection.
        west, south, east, north: Bounding box coordinates.
        zoom: Map zoom level (integer).
        transport_type: Optional transport-type filter.

    Returns:
        List of dicts with id, latitude, longitude, count, expansion_zoom, name.
    """
    # Out-of-range zooms fall back to a mid resolution.
    resolution = ZOOM_TO_RES.get(int(zoom), 5)
    nodes = _fetch_nodes(db, transport_type)
    if not nodes:
        return []

    # Group nodes by h3 cell.
    cells = {}
    for node in nodes:
        lat = node.get('latitude')
        lng = node.get('longitude')
        # Fix: the old pysupercluster path filtered out nodes without
        # coordinates in _build_index; this path dropped that check, and
        # comparing None against the bbox raises TypeError — skip explicitly.
        if lat is None or lng is None:
            continue
        # Rough bounding-box filter.  Assumes a non-antimeridian bbox
        # (west <= east) — TODO confirm with the map frontend.
        if lat < south or lat > north or lng < west or lng > east:
            continue
        cell = h3.latlng_to_cell(lat, lng, resolution)
        cells.setdefault(cell, []).append(node)

    results = []
    for cell, nodes_in_cell in cells.items():
        count = len(nodes_in_cell)
        if count == 1:
            # Single point — return actual node data.
            node = nodes_in_cell[0]
            results.append({
                'id': node.get('_key'),
                'latitude': node.get('latitude'),
                'longitude': node.get('longitude'),
                'count': 1,
                'expansion_zoom': None,
                'name': node.get('name'),
            })
        else:
            # Cluster — return cell centroid; zooming in two levels
            # (capped at the max zoom of 16) splits the cell.
            lat, lng = h3.cell_to_latlng(cell)
            results.append({
                'id': f"cluster-{cell}",
                'latitude': lat,
                'longitude': lng,
                'count': count,
                'expansion_zoom': min(zoom + 2, 16),
                'name': None,
            })

    logger.info("Returning %d clusters/points for zoom=%d res=%d", len(results), zoom, resolution)
    return results
def invalidate_cache(transport_type=None):
    """
    Invalidate cached node lists after nodes change in the database.

    Args:
        transport_type: If given, drop only that transport type's entry;
            otherwise drop everything.
    """
    with _cache_lock:
        if transport_type:
            cache_key = f"nodes:{transport_type}"
            # Fix: drop the entry from BOTH module-level caches — the
            # legacy _cluster_cache would otherwise keep serving stale data.
            _nodes_cache.pop(cache_key, None)
            _cluster_cache.pop(cache_key, None)
            # NOTE(review): the "nodes:all" entry also contains nodes of this
            # transport type and stays stale here — presumably accepted as a
            # coarse trade-off; confirm with callers.
        else:
            _nodes_cache.clear()
            _cluster_cache.clear()
    logger.info("Cluster cache invalidated")