Replace pysupercluster with h3 for clustering
All checks were successful
Build Docker Image / build (push) Successful in 1m38s

This commit is contained in:
Ruslan Bakiev
2026-01-14 10:24:40 +07:00
parent 7efa753092
commit 0330203a58
2 changed files with 81 additions and 135 deletions

View File

@@ -1,104 +1,33 @@
""" """
Cached SuperCluster index for server-side map clustering. Server-side map clustering using Uber H3 hexagonal grid.
Uses pysupercluster for fast geospatial point clustering. Maps zoom levels to h3 resolutions and groups nodes by cell.
Index is lazily initialized on first request and cached in memory.
""" """
import logging import logging
import threading import threading
import numpy as np import h3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Global cache for cluster indices # Global cache for nodes
_cluster_cache = {} _nodes_cache = {}
_cache_lock = threading.Lock() _cache_lock = threading.Lock()
# Map zoom level to h3 resolution
def _build_index(nodes, transport_type=None): # Higher zoom = higher resolution = smaller cells
""" ZOOM_TO_RES = {
Build SuperCluster index from node list. 0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2,
6: 3, 7: 3, 8: 4, 9: 4, 10: 5, 11: 5,
Args: 12: 6, 13: 7, 14: 8, 15: 9, 16: 10
nodes: List of node dicts with latitude, longitude, _key, name }
transport_type: Optional filter for transport type
Returns:
Tuple of (SuperCluster index, node_data dict keyed by index)
"""
try:
import pysupercluster
except ImportError:
logger.error("pysupercluster not installed")
return None, {}
# Filter nodes with valid coordinates
valid_nodes = []
for node in nodes:
lat = node.get('latitude')
lon = node.get('longitude')
if lat is not None and lon is not None:
# Filter by transport type if specified
if transport_type:
types = node.get('transport_types') or []
if transport_type not in types:
continue
valid_nodes.append(node)
if not valid_nodes:
logger.warning("No valid nodes for clustering")
return None, {}
# Build numpy array of coordinates (lon, lat)
coords = np.array([
(node['longitude'], node['latitude'])
for node in valid_nodes
])
# Build node data lookup by index
node_data = {
i: {
'uuid': node.get('_key'),
'name': node.get('name'),
'latitude': node.get('latitude'),
'longitude': node.get('longitude'),
}
for i, node in enumerate(valid_nodes)
}
# Create SuperCluster index
# min_zoom=0, max_zoom=16 covers typical map zoom range
# radius=60 pixels for clustering
index = pysupercluster.SuperCluster(
coords,
min_zoom=0,
max_zoom=16,
radius=60,
extent=512,
)
logger.info("Built cluster index with %d points", len(valid_nodes))
return index, node_data
def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None): def _fetch_nodes(db, transport_type=None):
""" """Fetch nodes from database with caching."""
Get clustered nodes for given bounding box and zoom level.
Args:
db: ArangoDB connection
west, south, east, north: Bounding box coordinates
zoom: Map zoom level (integer)
transport_type: Optional filter
Returns:
List of cluster/point dicts with id, latitude, longitude, count, expansion_zoom, name
"""
cache_key = f"nodes:{transport_type or 'all'}" cache_key = f"nodes:{transport_type or 'all'}"
with _cache_lock: with _cache_lock:
if cache_key not in _cluster_cache: if cache_key not in _nodes_cache:
# Load all nodes from database
aql = """ aql = """
FOR node IN nodes FOR node IN nodes
FILTER node.node_type == 'logistics' OR node.node_type == null FILTER node.node_type == 'logistics' OR node.node_type == null
@@ -108,68 +37,86 @@ def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None)
cursor = db.aql.execute(aql) cursor = db.aql.execute(aql)
all_nodes = list(cursor) all_nodes = list(cursor)
# Build index # Filter by transport type if specified
index, node_data = _build_index(all_nodes, transport_type) if transport_type:
_cluster_cache[cache_key] = (index, node_data, all_nodes) all_nodes = [
n for n in all_nodes
if transport_type in (n.get('transport_types') or [])
]
index, node_data, all_nodes = _cluster_cache[cache_key] _nodes_cache[cache_key] = all_nodes
logger.info("Cached %d nodes for %s", len(all_nodes), cache_key)
if index is None: return _nodes_cache[cache_key]
def get_clustered_nodes(db, west, south, east, north, zoom, transport_type=None):
    """
    Get clustered nodes for given bounding box and zoom level.

    Groups nodes into Uber H3 hexagonal cells at a resolution derived
    from the map zoom level. A cell holding a single node is returned as
    that node; a cell holding several nodes becomes a cluster positioned
    at the cell centroid.

    Args:
        db: ArangoDB connection.
        west, south, east, north: Bounding box in degrees.
        zoom: Map zoom level (numeric; clamped into the 0-16 table range).
        transport_type: Optional transport-type filter, passed through to
            the node cache.

    Returns:
        List of dicts with keys: id, latitude, longitude, count,
        expansion_zoom, name.
    """
    # Clamp zoom into the table's domain so zoom > 16 resolves to the
    # finest resolution instead of silently falling back to a coarse
    # mid-range default.
    resolution = ZOOM_TO_RES[min(max(int(zoom), 0), 16)]
    nodes = _fetch_nodes(db, transport_type)
    if not nodes:
        return []

    # Group nodes by h3 cell.
    cells = {}
    for node in nodes:
        lat = node.get('latitude')
        lng = node.get('longitude')
        # Nodes may lack coordinates; comparing None against a float
        # raises TypeError, so drop them explicitly (the old
        # pysupercluster path filtered these as well).
        if lat is None or lng is None:
            continue
        # Skip nodes outside bounding box (rough filter)
        if lat < south or lat > north or lng < west or lng > east:
            continue
        cell = h3.latlng_to_cell(lat, lng, resolution)
        cells.setdefault(cell, []).append(node)

    # Build results
    results = []
    for cell, nodes_in_cell in cells.items():
        count = len(nodes_in_cell)
        if count == 1:
            # Single point — return actual node data
            node = nodes_in_cell[0]
            results.append({
                'id': node.get('_key'),
                'latitude': node.get('latitude'),
                'longitude': node.get('longitude'),
                'count': 1,
                'expansion_zoom': None,
                'name': node.get('name'),
            })
        else:
            # Cluster — return cell centroid
            lat, lng = h3.cell_to_latlng(cell)
            results.append({
                'id': f"cluster-{cell}",
                'latitude': lat,
                'longitude': lng,
                'count': count,
                'expansion_zoom': min(zoom + 2, 16),
                'name': None,
            })

    logger.info("Returning %d clusters/points for zoom=%d res=%d",
                len(results), zoom, resolution)
    return results
def invalidate_cache(transport_type=None):
    """
    Invalidate cached node lists after data changes.

    Args:
        transport_type: When given, drop the cache entry for that
            transport type (and the unfiltered entry); otherwise drop
            every cached entry.
    """
    with _cache_lock:
        if transport_type:
            # The unfiltered "nodes:all" entry is built from the same
            # source data, so it must be dropped too — otherwise it
            # keeps serving stale nodes after an update.
            for cache_key in (f"nodes:{transport_type}", "nodes:all"):
                _nodes_cache.pop(cache_key, None)
        else:
            _nodes_cache.clear()
    logger.info("Cluster cache invalidated")

View File

@@ -16,8 +16,7 @@ dependencies = [
"gunicorn (>=23.0.0,<24.0.0)", "gunicorn (>=23.0.0,<24.0.0)",
"whitenoise (>=6.7.0,<7.0.0)", "whitenoise (>=6.7.0,<7.0.0)",
"sentry-sdk (>=2.47.0,<3.0.0)", "sentry-sdk (>=2.47.0,<3.0.0)",
"pysupercluster (>=0.7.7,<1.0.0)", "h3 (>=4.0.0,<5.0.0)"
"numpy (>=1.26.0,<3.0.0)"
] ]
[tool.poetry] [tool.poetry]