182 lines
5.5 KiB
Python
182 lines
5.5 KiB
Python
"""Asset 3: Detect ball positions using YOLO"""
|
|
|
|
import sys
|
|
import cv2
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional
|
|
from dagster import asset, AssetExecutionContext
|
|
from tqdm import tqdm
|
|
|
|
# Add src to path to import existing ball detector
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
|
|
|
|
@asset(
    io_manager_key="json_io_manager",
    compute_kind="yolo",
    description="Detect ball positions on all frames using YOLO"
)
def detect_ball_positions(
    context: AssetExecutionContext,
    extract_video_frames: Dict
) -> List[Dict]:
    """
    Detect ball positions on all extracted frames.

    Inputs:
        - extract_video_frames: metadata from frame extraction; this function
          reads its ``frames_dir`` and ``num_frames`` keys
        - <frames_dir>/frame_NNNN.jpg: the extracted frames, numbered from 0

    Outputs:
        - data/detect_ball_positions.json (presumably written by the
          ``json_io_manager`` from the returned list -- confirm against the
          IO manager configuration)
        - annotated preview images, written by ``_save_detection_preview``

    Returns:
        One dict per frame index in ``range(num_frames)``, in order, with:
        - frame: frame number
        - x: pixel x coordinate (or None if not detected)
        - y: pixel y coordinate (or None if not detected)
        - confidence: detection confidence (0.0 if not detected)
        - diameter_px: estimated ball diameter in pixels (or None)
        - bbox: list of four floats as returned by the detector (or None)
    """
    # Imported lazily so the module can be loaded without the detector
    # (and its model dependencies) being importable.
    from src.ball_detector import BallDetector

    frames_dir = Path(extract_video_frames['frames_dir'])
    num_frames = extract_video_frames['num_frames']

    context.log.info("Initializing YOLO ball detector...")

    detector = BallDetector(
        model_id="pickleball-moving-ball/5",
        confidence_threshold=0.3,  # Lower threshold to catch more detections
        slice_enabled=False  # Disable slicing for faster Hosted API inference
    )

    context.log.info(f"Processing {num_frames} frames for ball detection...")

    detections = []
    frames_with_ball = 0

    for i in tqdm(range(num_frames), desc="Detecting ball"):
        frame_path = frames_dir / f"frame_{i:04d}.jpg"

        record = _detect_in_frame(context, detector, i, frame_path)
        detections.append(record)
        if record["x"] is not None:
            frames_with_ball += 1

        # Log progress every 20 frames
        if (i + 1) % 20 == 0:
            detection_rate = frames_with_ball / (i + 1) * 100
            context.log.info(f"Processed {i + 1}/{num_frames} frames. Detection rate: {detection_rate:.1f}%")

    # Guard against an empty extraction so the summary never divides by zero.
    detection_rate = frames_with_ball / num_frames * 100 if num_frames else 0.0
    context.log.info(f"✓ Ball detected in {frames_with_ball}/{num_frames} frames ({detection_rate:.1f}%)")

    # Save ALL detection images: the cap is the number of records, so no
    # detection is dropped however long the clip is.
    _save_detection_preview(context, frames_dir, detections, num_preview=len(detections))

    return detections


def _empty_detection(frame_num: int) -> Dict:
    """Build the record used for a frame in which no ball was found."""
    return {
        "frame": frame_num,
        "x": None,
        "y": None,
        "confidence": 0.0,
        "diameter_px": None,
        "bbox": None
    }


def _detect_in_frame(
    context: AssetExecutionContext,
    detector,
    frame_num: int,
    frame_path: Path
) -> Dict:
    """Run the detector on one frame file and return its detection record."""
    if not frame_path.exists():
        context.log.warning(f"Frame {frame_num} not found: {frame_path}")
        return _empty_detection(frame_num)

    # cv2.imread signals an unreadable/corrupt file by returning None
    # instead of raising, so check before handing the frame to the detector.
    frame = cv2.imread(str(frame_path))
    if frame is None:
        context.log.warning(f"Frame {frame_num} could not be read: {frame_path}")
        return _empty_detection(frame_num)

    results = detector.detect(frame)
    if not results:
        return _empty_detection(frame_num)

    # NOTE(review): assumes the detector returns results ordered by
    # descending confidence, so the first entry is the best one -- confirm.
    ball = results[0]

    # The diameter is the mean of the box width and height; the box is
    # indexed here as [x1, y1, x2, y2].
    bbox = ball.get('bbox')
    diameter_px = None
    if bbox:
        width = bbox[2] - bbox[0]
        height = bbox[3] - bbox[1]
        diameter_px = (width + height) / 2

    return {
        "frame": frame_num,
        "x": float(ball['center'][0]),
        "y": float(ball['center'][1]),
        "confidence": float(ball['confidence']),
        # A degenerate zero-size box is reported as "no diameter".
        "diameter_px": float(diameter_px) if diameter_px else None,
        "bbox": [float(b) for b in bbox] if bbox else None
    }
|
|
|
|
|
|
def _save_detection_preview(
    context: AssetExecutionContext,
    frames_dir: Path,
    detections: List[Dict],
    num_preview: int = 5
):
    """
    Save preview images showing ball detections.

    For up to ``num_preview`` detection records whose ``x`` is not None, the
    matching ``frame_NNNN.jpg`` is loaded from ``frames_dir``, annotated with
    a filled circle at the ball centre, the bounding box (when present) and
    the confidence value, and written to
    ``data/<run_id>/ball_detections/detection_frame_NNNN.jpg``.

    Args:
        context: execution context; supplies ``run_id`` and the logger.
        frames_dir: directory containing the extracted frame images.
        detections: detection records with ``frame``, ``x``, ``y``,
            ``confidence`` and ``bbox`` keys.
        num_preview: maximum number of detected frames to render; values
            below zero are treated as zero.
    """
    run_id = context.run_id
    preview_dir = Path(f"data/{run_id}/ball_detections")
    preview_dir.mkdir(parents=True, exist_ok=True)

    # Find first N frames with detections. Clamp at zero so that a negative
    # count cannot turn the slice into "all but the last |N|".
    limit = max(num_preview, 0)
    detected_frames = [d for d in detections if d['x'] is not None][:limit]

    saved = 0
    for detection in detected_frames:
        frame_num = detection['frame']
        frame_path = frames_dir / f"frame_{frame_num:04d}.jpg"

        if not frame_path.exists():
            continue

        # cv2.imread returns None for unreadable files; drawing on None
        # would raise, so skip the frame and keep rendering the rest.
        frame = cv2.imread(str(frame_path))
        if frame is None:
            context.log.warning(f"Could not read frame for preview: {frame_path}")
            continue

        # Draw ball
        x, y = int(detection['x']), int(detection['y'])
        cv2.circle(frame, (x, y), 8, (0, 0, 255), -1)

        # Draw bbox if available
        bbox = detection.get('bbox')
        if bbox:
            cv2.rectangle(
                frame,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[2]), int(bbox[3])),
                (0, 255, 0),
                2
            )

        # Draw confidence
        cv2.putText(
            frame,
            f"Conf: {detection['confidence']:.2f}",
            (x + 15, y - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (255, 255, 255),
            2
        )

        # Save preview; cv2.imwrite reports failure through its return value
        # rather than an exception, so count only the writes that succeeded.
        preview_path = preview_dir / f"detection_frame_{frame_num:04d}.jpg"
        if cv2.imwrite(str(preview_path), frame):
            saved += 1
        else:
            context.log.warning(f"Failed to write preview image: {preview_path}")

    context.log.info(f"Saved {saved} preview images to {preview_dir}")
|