diff --git a/geo_app/schema.py b/geo_app/schema.py index 91f9c38..e4271c4 100644 --- a/geo_app/schema.py +++ b/geo_app/schema.py @@ -1474,87 +1474,87 @@ class Query(graphene.ObjectType): nodes_col = db.collection('nodes') start = nodes_col.get(source_uuid) if not start: - logger.warning("Source node %s not found for nearest hubs", source_uuid) - return [] + logger.warning("Source node %s not found for nearest hubs, falling back to radius search", source_uuid) + source_uuid = None + else: + def is_target_hub(doc): + if doc.get('_key') == source_uuid: + return False + if doc.get('node_type') not in ('logistics', None): + return False + types = doc.get('transport_types') or [] + return ('rail' in types) or ('sea' in types) - def is_target_hub(doc): - if doc.get('_key') == source_uuid: - return False - if doc.get('node_type') not in ('logistics', None): - return False - types = doc.get('transport_types') or [] - return ('rail' in types) or ('sea' in types) + def fetch_neighbors(node_key): + aql = """ + FOR edge IN edges + FILTER edge.transport_type IN ['auto', 'rail', 'offer'] + FILTER edge._from == @node_id OR edge._to == @node_id + LET neighbor_id = edge._from == @node_id ? edge._to : edge._from + LET neighbor = DOCUMENT(neighbor_id) + FILTER neighbor != null + RETURN { + neighbor_key: neighbor._key, + neighbor_doc: neighbor, + transport_type: edge.transport_type, + distance_km: edge.distance_km, + travel_time_seconds: edge.travel_time_seconds + } + """ + cursor = db.aql.execute( + aql, + bind_vars={'node_id': f"nodes/{node_key}"}, + ) + return list(cursor) - def fetch_neighbors(node_key): - aql = """ - FOR edge IN edges - FILTER edge.transport_type IN ['auto', 'rail', 'offer'] - FILTER edge._from == @node_id OR edge._to == @node_id - LET neighbor_id = edge._from == @node_id ? edge._to : edge._from - LET neighbor = DOCUMENT(neighbor_id) - FILTER neighbor != null - RETURN { - neighbor_key: neighbor._key, - neighbor_doc: neighbor, - transport_type: edge.transport_type, - distance_km: edge.distance_km, - travel_time_seconds: edge.travel_time_seconds - } - """ - cursor = db.aql.execute( - aql, - bind_vars={'node_id': f"nodes/{node_key}"}, - ) - return list(cursor) + queue = [] + counter = 0 + heapq.heappush(queue, (0, counter, source_uuid)) + visited = {} + node_docs = {source_uuid: start} + found = [] + expansions = 0 - queue = [] - counter = 0 - heapq.heappush(queue, (0, counter, source_uuid)) - visited = {} - node_docs = {source_uuid: start} - found = [] - expansions = 0 - - while queue and len(found) < limit and expansions < Query.MAX_EXPANSIONS: - cost, _, node_key = heapq.heappop(queue) - if node_key in visited and cost > visited[node_key]: - continue - visited[node_key] = cost - - node_doc = node_docs.get(node_key) - if node_doc and is_target_hub(node_doc): - found.append(node_doc) - if len(found) >= limit: - break - - neighbors = fetch_neighbors(node_key) - expansions += 1 - for neighbor in neighbors: - neighbor_key = neighbor.get('neighbor_key') - if not neighbor_key: + while queue and len(found) < limit and expansions < Query.MAX_EXPANSIONS: + cost, _, node_key = heapq.heappop(queue) + if node_key in visited and cost > visited[node_key]: continue - node_docs[neighbor_key] = neighbor.get('neighbor_doc') - step_cost = neighbor.get('travel_time_seconds') or neighbor.get('distance_km') or 0 - new_cost = cost + step_cost - if neighbor_key in visited and new_cost >= visited[neighbor_key]: - continue - counter += 1 - heapq.heappush(queue, (new_cost, counter, neighbor_key)) + visited[node_key] = cost - hubs = [] - for node in found: - hubs.append(NodeType( - uuid=node.get('_key'), - name=node.get('name'), - latitude=node.get('latitude'), - longitude=node.get('longitude'), - country=node.get('country'), - country_code=node.get('country_code'), - synced_at=node.get('synced_at'), - transport_types=node.get('transport_types') or [], - edges=[], - )) - return hubs + node_doc = node_docs.get(node_key) + if node_doc and is_target_hub(node_doc): + found.append(node_doc) + if len(found) >= limit: + break + + neighbors = fetch_neighbors(node_key) + expansions += 1 + for neighbor in neighbors: + neighbor_key = neighbor.get('neighbor_key') + if not neighbor_key: + continue + node_docs[neighbor_key] = neighbor.get('neighbor_doc') + step_cost = neighbor.get('travel_time_seconds') or neighbor.get('distance_km') or 0 + new_cost = cost + step_cost + if neighbor_key in visited and new_cost >= visited[neighbor_key]: + continue + counter += 1 + heapq.heappush(queue, (new_cost, counter, neighbor_key)) + + hubs = [] + for node in found: + hubs.append(NodeType( + uuid=node.get('_key'), + name=node.get('name'), + latitude=node.get('latitude'), + longitude=node.get('longitude'), + country=node.get('country'), + country_code=node.get('country_code'), + synced_at=node.get('synced_at'), + transport_types=node.get('transport_types') or [], + edges=[], + )) + return hubs if product_uuid: # Find hubs that have offers for this product within radius @@ -1626,6 +1626,46 @@ class Query(graphene.ObjectType): db = get_db() ensure_graph() + # If hub_uuid + product_uuid provided, use graph search to return only offers with routes. + if hub_uuid and product_uuid: + try: + nodes_col = db.collection('nodes') + expanded_limit = max(limit * 5, limit) + route_options = Query.resolve_offers_by_hub( + Query, info, hub_uuid, product_uuid, expanded_limit + ) + offers = [] + for option in route_options or []: + if not option.routes: + continue + node = nodes_col.get(option.source_uuid) + if not node: + continue + offers.append(OfferWithRouteType( + uuid=node['_key'], + product_uuid=node.get('product_uuid'), + product_name=node.get('product_name'), + supplier_uuid=node.get('supplier_uuid'), + supplier_name=node.get('supplier_name'), + latitude=node.get('latitude'), + longitude=node.get('longitude'), + country=node.get('country'), + country_code=node.get('country_code'), + price_per_unit=node.get('price_per_unit'), + currency=node.get('currency'), + quantity=node.get('quantity'), + unit=node.get('unit'), + distance_km=option.distance_km, + routes=option.routes, + )) + if len(offers) >= limit: + break + logger.info("Found %d offers by graph for hub %s", len(offers), hub_uuid) + return offers + except Exception as e: + logger.error("Error finding offers by hub %s: %s", hub_uuid, e) + return [] + aql = """ FOR offer IN nodes FILTER offer.node_type == 'offer'