diff --git a/offers/management/commands/seed_exchange.py b/offers/management/commands/seed_exchange.py index 0c66e5d..5392a87 100644 --- a/offers/management/commands/seed_exchange.py +++ b/offers/management/commands/seed_exchange.py @@ -2,8 +2,11 @@ Seed Suppliers and Offers for African cocoa belt. Creates offers via Temporal workflow so they sync to the graph. """ +import csv +import os import random import uuid +from pathlib import Path from decimal import Decimal import time @@ -65,6 +68,9 @@ SUPPLIER_NAMES = [ "Dar Coast Commodities", "Maputo Export House", ] +# Default GLEIF Africa LEI dataset path (repo-local) +DEFAULT_GLEIF_PATH = "datasets/gleif/africa_lei_companies.csv" + # Fixed product catalog (10 items) with realistic prices per ton (USD) PRODUCT_CATALOG = [ {"name": "Cocoa Beans", "category": "Cocoa", "price": Decimal("2450.00")}, @@ -138,8 +144,8 @@ class Command(BaseCommand): parser.add_argument( "--geo-url", type=str, - default="http://geo:8000/graphql/public/", - help="Geo service GraphQL URL (default: http://geo:8000/graphql/public/)", + default=None, + help="Geo service GraphQL URL (defaults to GEO_INTERNAL_URL env var)", ) parser.add_argument( "--odoo-url", @@ -176,6 +182,12 @@ class Command(BaseCommand): default=None, help="Filter offers by product name (e.g., 'Cocoa Beans')", ) + parser.add_argument( + "--company-csv", + type=str, + default=None, + help="Path to CSV with real company names (default: datasets/gleif/africa_lei_companies.csv)", + ) def handle(self, *args, **options): if options["clear"]: @@ -194,13 +206,17 @@ class Command(BaseCommand): use_bulk = options["bulk"] bulk_size = max(1, options["bulk_size"]) sleep_ms = max(0, options["sleep_ms"]) - geo_url = options["geo_url"] + geo_url = options["geo_url"] or os.getenv("GEO_INTERNAL_URL") or os.getenv("GEO_URL") + if not geo_url: + self.stdout.write(self.style.ERROR("Geo URL is not set. Provide --geo-url or GEO_INTERNAL_URL.")) + return odoo_url = options["odoo_url"] product_filter = options["product"] ensure_products = options["ensure_products"] odoo_db = options["odoo_db"] odoo_user = options["odoo_user"] odoo_password = options["odoo_password"] + company_csv = options["company_csv"] # Fetch products from Odoo self.stdout.write("Fetching products from Odoo...") @@ -233,14 +249,13 @@ class Command(BaseCommand): hubs = self._fetch_african_hubs(geo_url) if not hubs: - self.stdout.write(self.style.WARNING( - "No African hubs found. Using default locations." - )) - hubs = self._default_african_hubs() + self.stdout.write(self.style.ERROR("No African hubs found from geo service. Aborting seed.")) + return self.stdout.write(f"Found {len(hubs)} African hubs") # Create suppliers + self._company_pool = self._load_company_pool(company_csv) self.stdout.write(f"Creating {suppliers_count} suppliers...") new_suppliers = self._create_suppliers(suppliers_count, hubs) self.stdout.write(self.style.SUCCESS(f"Created {len(new_suppliers)} suppliers")) @@ -572,15 +587,28 @@ class Command(BaseCommand): lat += random.uniform(-0.5, 0.5) lng += random.uniform(-0.5, 0.5) - name = self._generate_supplier_name(idx) + company = self._pick_company(idx) + if company: + name = company["name"] + company_code = company.get("country_code") + mapped_name = self._country_name_from_code(company_code) + if mapped_name: + country = mapped_name + country_code = company_code + supplier_uuid = self._stable_uuid("supplier", company.get("lei") or name) + team_uuid = self._stable_uuid("team", company.get("lei") or name) + else: + name = self._generate_supplier_name(idx) + supplier_uuid = str(uuid.uuid4()) + team_uuid = str(uuid.uuid4()) description = ( f"{name} is a reliable supplier based in {country}, " "focused on consistent quality and transparent logistics." ) profile = SupplierProfile.objects.create( - uuid=str(uuid.uuid4()), - team_uuid=str(uuid.uuid4()), + uuid=supplier_uuid, + team_uuid=team_uuid, name=name, description=description, country=country, @@ -600,6 +628,68 @@ class Command(BaseCommand): return SUPPLIER_NAMES[index] return f"{random.choice(SUPPLIER_NAMES)} Group" + def _find_default_company_csv(self) -> str | None: + """Locate default company CSV in repo (datasets/gleif/africa_lei_companies.csv).""" + here = Path(__file__).resolve() + for parent in here.parents: + candidate = parent / DEFAULT_GLEIF_PATH + if candidate.exists(): + return str(candidate) + return None + + def _load_company_pool(self, csv_path: str | None) -> list[dict]: + """Load real company names from CSV; returns list of dicts.""" + path = csv_path or self._find_default_company_csv() + if not path or not os.path.exists(path): + self.stdout.write(self.style.WARNING("Company CSV not found; using fallback names.")) + return [] + + companies = [] + seen = set() + try: + with open(path, newline="", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + name = (row.get("entity_name") or "").strip() + if not name: + continue + if name in seen: + continue + seen.add(name) + companies.append( + { + "name": name, + "lei": (row.get("lei") or "").strip(), + "country_code": (row.get("legal_address_country") or row.get("headquarters_country") or "").strip(), + "city": (row.get("legal_address_city") or row.get("headquarters_city") or "").strip(), + } + ) + except Exception as e: + self.stdout.write(self.style.WARNING(f"Failed to read company CSV: {e}")) + return [] + + random.shuffle(companies) + self.stdout.write(f"Loaded {len(companies)} company names from CSV") + return companies + + def _pick_company(self, index: int) -> dict | None: + if not getattr(self, "_company_pool", None): + return None + if index < len(self._company_pool): + return self._company_pool[index] + return random.choice(self._company_pool) + + def _stable_uuid(self, prefix: str, value: str) -> str: + return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{prefix}:{value}")) + + def _country_name_from_code(self, code: str | None) -> str | None: + if not code: + return None + for name, country_code, _, _ in AFRICAN_COUNTRIES: + if country_code == code: + return name + return None + def _price_for_product(self, product_name: str) -> Decimal: for item in PRODUCT_CATALOG: if item["name"].lower() == product_name.lower():