Files
pickle_vision/dagster_project/assets/net_detection.py
Ruslan Bakiev 549fd1da9d Initial commit
2026-03-06 09:43:52 +07:00

73 lines
2.1 KiB
Python

"""Asset: Detect tennis/pickleball net using Roboflow"""
import os
import cv2
import numpy as np
from pathlib import Path
from typing import Dict
from dagster import asset, AssetExecutionContext
from inference_sdk import InferenceHTTPClient
@asset(
    io_manager_key="json_io_manager",
    compute_kind="roboflow",
    description="Detect pickleball/tennis net using Roboflow model"
)
def detect_net(
    context: AssetExecutionContext,
    extract_video_frames: Dict,
    detect_court_keypoints: Dict
) -> Dict:
    """
    Detect the net on the first extracted frame using a Roboflow model.

    NO FALLBACKS - if the model doesn't detect the net, this will fail.

    NOTE: this asset is unfinished. The Roboflow model id has not been
    chosen yet and the response parsing is not written, so every run
    currently ends in ``NotImplementedError``. The model id can be supplied
    via the ``ROBOFLOW_NET_MODEL_ID`` environment variable; while it is
    unset, the asset fails before making any network request.

    Inputs:
        - extract_video_frames: frame metadata; must contain ``frames_dir``,
          the directory holding ``frame_0000.jpg``.
        - detect_court_keypoints: court corners (intended for visualization).
          Not read yet; kept so this asset stays downstream of it.

    Returns:
        Dict with net detection data (not produced yet - see NOTE).

    Raises:
        KeyError: ``extract_video_frames`` has no ``frames_dir`` entry.
        FileNotFoundError: the first frame is missing or cannot be decoded.
        ValueError: ``ROBOFLOW_API_KEY`` is not set.
        NotImplementedError: the model id is not configured, or the
            response parsing has not been implemented.
    """
    frames_dir = Path(extract_video_frames['frames_dir'])
    first_frame_path = frames_dir / "frame_0000.jpg"
    context.log.info(f"Loading first frame: {first_frame_path}")

    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    frame = cv2.imread(str(first_frame_path))
    if frame is None:
        raise FileNotFoundError(f"Could not read first frame: {first_frame_path}")
    h, w = frame.shape[:2]
    context.log.info(f"Frame dimensions: {w}x{h}")

    api_key = os.getenv("ROBOFLOW_API_KEY")
    if not api_key:
        raise ValueError("ROBOFLOW_API_KEY environment variable is not set")

    # Model id is not known yet. Fail before the network call rather than
    # sending a placeholder id that the API cannot resolve.
    model_id = os.getenv("ROBOFLOW_NET_MODEL_ID")
    if not model_id:
        raise NotImplementedError(
            "Waiting for correct Roboflow model from user "
            "(set ROBOFLOW_NET_MODEL_ID to enable the Roboflow call)"
        )

    context.log.info("Detecting net using Roboflow model...")
    client = InferenceHTTPClient(
        api_url="https://serverless.roboflow.com",
        api_key=api_key
    )
    result = client.infer(str(first_frame_path), model_id=model_id)
    context.log.info(f"Roboflow response: {result}")

    # TODO: parse `result` once the model's output format is known, build the
    # detection dict (and optional preview image), and return it.
    raise NotImplementedError(
        f"Parsing of Roboflow model '{model_id}' output is not implemented yet"
    )