Initial commit
This commit is contained in:
72
dagster_project/assets/net_detection.py
Normal file
72
dagster_project/assets/net_detection.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Asset: Detect tennis/pickleball net using Roboflow"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from dagster import asset, AssetExecutionContext
|
||||
from inference_sdk import InferenceHTTPClient
|
||||
|
||||
|
||||
@asset(
|
||||
io_manager_key="json_io_manager",
|
||||
compute_kind="roboflow",
|
||||
description="Detect pickleball/tennis net using Roboflow model"
|
||||
)
|
||||
def detect_net(
|
||||
context: AssetExecutionContext,
|
||||
extract_video_frames: Dict,
|
||||
detect_court_keypoints: Dict
|
||||
) -> Dict:
|
||||
"""
|
||||
Detect net on first frame using Roboflow model
|
||||
|
||||
NO FALLBACKS - if model doesn't detect net, this will fail
|
||||
|
||||
Inputs:
|
||||
- extract_video_frames: frame metadata
|
||||
- detect_court_keypoints: court corners (for visualization)
|
||||
|
||||
Outputs:
|
||||
- data/{run_id}/net_detection_preview.jpg: visualization
|
||||
- JSON with net detection results
|
||||
|
||||
Returns:
|
||||
Dict with net detection data
|
||||
"""
|
||||
run_id = context.run_id
|
||||
frames_dir = Path(extract_video_frames['frames_dir'])
|
||||
first_frame_path = frames_dir / "frame_0000.jpg"
|
||||
|
||||
context.log.info(f"Loading first frame: {first_frame_path}")
|
||||
|
||||
# Load frame
|
||||
frame = cv2.imread(str(first_frame_path))
|
||||
h, w = frame.shape[:2]
|
||||
context.log.info(f"Frame dimensions: {w}x{h}")
|
||||
|
||||
# Get API key
|
||||
api_key = os.getenv("ROBOFLOW_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("ROBOFLOW_API_KEY environment variable is not set")
|
||||
|
||||
context.log.info("Detecting net using Roboflow model...")
|
||||
|
||||
client = InferenceHTTPClient(
|
||||
api_url="https://serverless.roboflow.com",
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
# Call Roboflow model - MODEL_ID WILL BE PROVIDED BY USER
|
||||
# Placeholder - user will provide correct model
|
||||
model_id = "MODEL_ID_PLACEHOLDER"
|
||||
|
||||
result = client.infer(str(first_frame_path), model_id=model_id)
|
||||
|
||||
context.log.info(f"Roboflow response: {result}")
|
||||
|
||||
# TODO: Parse result based on actual model output format
|
||||
# User will provide correct model and we'll update parsing logic
|
||||
|
||||
raise NotImplementedError("Waiting for correct Roboflow model from user")
|
||||
Reference in New Issue
Block a user