def _bbox_filters(var, wraps_antimeridian):
    """Return the AQL FILTER lines that keep ``var`` inside the bound bbox.

    The generated text references the bind parameters ``@west``, ``@south``,
    ``@east`` and ``@north``; only the (internal, constant) document variable
    name is interpolated, never a caller-supplied value.

    Args:
        var: Name of the AQL variable holding the document to test.
        wraps_antimeridian: True when the box runs from ``west`` eastwards
            across +/-180 degrees to ``east`` (i.e. ``west > east``).
    """
    # A wrapping box is the union of [west, 180] and [-180, east], so the two
    # longitude bounds must be OR-ed instead of AND-ed.
    lng_op = 'OR' if wraps_antimeridian else 'AND'
    return (
        # Null check comes first so the OR form never matches a missing
        # longitude (null sorts below every number in AQL comparisons).
        f"FILTER {var}.latitude != null AND {var}.longitude != null\n"
        f"FILTER {var}.latitude >= @south AND {var}.latitude <= @north\n"
        f"FILTER {var}.longitude >= @west {lng_op} {var}.longitude <= @east\n"
    )


def _build_nodes_query(node_type, wraps_antimeridian):
    """Return the AQL text that selects nodes of ``node_type`` in the bbox."""
    if node_type == 'offer':
        return (
            "FOR node IN nodes\n"
            "FILTER node.node_type == 'offer'\n"
            + _bbox_filters('node', wraps_antimeridian)
            + "RETURN node\n"
        )

    if node_type == 'supplier':
        # Get suppliers that have offers (aggregate through offers)
        return (
            "FOR offer IN nodes\n"
            "FILTER offer.node_type == 'offer'\n"
            "FILTER offer.supplier_uuid != null\n"
            "LET supplier = DOCUMENT(CONCAT('nodes/', offer.supplier_uuid))\n"
            "FILTER supplier != null\n"
            + _bbox_filters('supplier', wraps_antimeridian)
            + """COLLECT sup_uuid = offer.supplier_uuid INTO offers
            LET sup = DOCUMENT(CONCAT('nodes/', sup_uuid))
            RETURN {
                _key: sup_uuid,
                name: sup.name,
                latitude: sup.latitude,
                longitude: sup.longitude,
                country: sup.country,
                country_code: sup.country_code,
                node_type: 'supplier',
                offers_count: LENGTH(offers)
            }
            """
        )

    # logistics (default) -- any other node_type value also lands here
    return (
        "FOR node IN nodes\n"
        "FILTER node.node_type == 'logistics' OR node.node_type == null\n"
        + _bbox_filters('node', wraps_antimeridian)
        + "RETURN node\n"
    )


def _fetch_nodes(db, west, south, east, north, transport_type=None, node_type=None):
    """Fetch nodes from database for a bounding box.

    The bounding box is applied inside the AQL query through bind parameters.
    When ``west > east`` the box is treated as crossing the antimeridian, so
    nodes with ``longitude >= west`` or ``longitude <= east`` are returned
    instead of an always-empty result.

    Args:
        db: Database connection (must expose ``db.aql.execute``)
        west, south, east, north: Bounding box coordinates
        transport_type: Filter by transport type (auto, rail, sea, air)
        node_type: Type of nodes to fetch ('logistics', 'offer', 'supplier')

    Returns:
        List of the documents yielded by the query cursor. For 'supplier' each
        item is the aggregated projection built by the query (including
        ``offers_count``); otherwise the node documents themselves.
    """
    bind_vars = {
        'west': west,
        'south': south,
        'east': east,
        'north': north,
    }

    aql = _build_nodes_query(node_type, wraps_antimeridian=west > east)
    cursor = db.aql.execute(aql, bind_vars=bind_vars)
    nodes = list(cursor)

    # Filter by transport type if specified (only for logistics nodes)
    if transport_type and node_type in (None, 'logistics'):
        nodes = [
            n for n in nodes
            # transport_types may be missing or null on a document
            if transport_type in (n.get('transport_types') or [])
        ]

    return nodes
+6,7 @@ import requests import graphene from django.conf import settings from .arango_client import get_db, ensure_graph -from .cluster_index import get_clustered_nodes, invalidate_cache +from .cluster_index import get_clustered_nodes logger = logging.getLogger(__name__)