"""Asset 3: Detect ball positions using YOLO""" import sys import cv2 from pathlib import Path from typing import Dict, List, Optional from dagster import asset, AssetExecutionContext from tqdm import tqdm # Add src to path to import existing ball detector sys.path.insert(0, str(Path(__file__).parent.parent.parent)) @asset( io_manager_key="json_io_manager", compute_kind="yolo", description="Detect ball positions on all frames using YOLO" ) def detect_ball_positions( context: AssetExecutionContext, extract_video_frames: Dict ) -> List[Dict]: """ Detect ball positions on all extracted frames Inputs: - extract_video_frames: metadata from frame extraction - data/frames/*.jpg: all extracted frames Outputs: - data/detect_ball_positions.json Returns: List of dicts with: - frame: frame number - x: pixel x coordinate (or None if not detected) - y: pixel y coordinate (or None if not detected) - confidence: detection confidence (0-1) - diameter_px: estimated ball diameter in pixels """ from src.ball_detector import BallDetector frames_dir = Path(extract_video_frames['frames_dir']) num_frames = extract_video_frames['num_frames'] context.log.info(f"Initializing YOLO ball detector...") # Initialize detector detector = BallDetector( model_id="pickleball-moving-ball/5", confidence_threshold=0.3, # Lower threshold to catch more detections slice_enabled=False # Disable slicing for faster Hosted API inference ) context.log.info(f"Processing {num_frames} frames for ball detection...") detections = [] frames_with_ball = 0 for i in tqdm(range(num_frames), desc="Detecting ball"): frame_path = frames_dir / f"frame_{i:04d}.jpg" if not frame_path.exists(): context.log.warning(f"Frame {i} not found: {frame_path}") detections.append({ "frame": i, "x": None, "y": None, "confidence": 0.0, "diameter_px": None, "bbox": None }) continue # Load frame frame = cv2.imread(str(frame_path)) # Detect ball results = detector.detect(frame) if results and len(results) > 0: # Take highest confidence detection ball = results[0] # Calculate diameter from bbox bbox = ball.get('bbox') diameter_px = None if bbox: width = bbox[2] - bbox[0] height = bbox[3] - bbox[1] diameter_px = (width + height) / 2 detections.append({ "frame": i, "x": float(ball['center'][0]), "y": float(ball['center'][1]), "confidence": float(ball['confidence']), "diameter_px": float(diameter_px) if diameter_px else None, "bbox": [float(b) for b in bbox] if bbox else None }) frames_with_ball += 1 else: # No detection detections.append({ "frame": i, "x": None, "y": None, "confidence": 0.0, "diameter_px": None, "bbox": None }) # Log progress every 20 frames if (i + 1) % 20 == 0: detection_rate = frames_with_ball / (i + 1) * 100 context.log.info(f"Processed {i + 1}/{num_frames} frames. Detection rate: {detection_rate:.1f}%") detection_rate = frames_with_ball / num_frames * 100 context.log.info(f"✓ Ball detected in {frames_with_ball}/{num_frames} frames ({detection_rate:.1f}%)") # Save ALL detection images _save_detection_preview(context, frames_dir, detections, num_preview=999) return detections def _save_detection_preview( context: AssetExecutionContext, frames_dir: Path, detections: List[Dict], num_preview: int = 5 ): """Save preview images showing ball detections""" run_id = context.run_id preview_dir = Path(f"data/{run_id}/ball_detections") preview_dir.mkdir(parents=True, exist_ok=True) # Find first N frames with detections detected_frames = [d for d in detections if d['x'] is not None][:num_preview] for detection in detected_frames: frame_num = detection['frame'] frame_path = frames_dir / f"frame_{frame_num:04d}.jpg" if not frame_path.exists(): continue frame = cv2.imread(str(frame_path)) # Draw ball x, y = int(detection['x']), int(detection['y']) cv2.circle(frame, (x, y), 8, (0, 0, 255), -1) # Draw bbox if available if detection['bbox']: bbox = detection['bbox'] cv2.rectangle( frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (0, 255, 0), 2 ) # Draw confidence cv2.putText( frame, f"Conf: {detection['confidence']:.2f}", (x + 15, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2 ) # Save preview preview_path = preview_dir / f"detection_frame_{frame_num:04d}.jpg" cv2.imwrite(str(preview_path), frame) context.log.info(f"Saved {len(detected_frames)} preview images to {preview_dir}")