Initial commit
This commit is contained in:
181
dagster_project/assets/ball_detection.py
Normal file
181
dagster_project/assets/ball_detection.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Asset 3: Detect ball positions using YOLO"""
|
||||
|
||||
import sys
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from dagster import asset, AssetExecutionContext
|
||||
from tqdm import tqdm
|
||||
|
||||
# Add src to path to import existing ball detector
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
|
||||
@asset(
    io_manager_key="json_io_manager",
    compute_kind="yolo",
    description="Detect ball positions on all frames using YOLO"
)
def detect_ball_positions(
    context: AssetExecutionContext,
    extract_video_frames: Dict
) -> List[Dict]:
    """
    Detect ball positions on all extracted frames.

    Frames are read from ``<frames_dir>/frame_NNNN.jpg`` for every index in
    ``range(num_frames)``. Exactly one record is produced per index, so the
    returned list always has ``num_frames`` entries, even when a frame is
    missing, unreadable, or contains no ball.

    Inputs:
        - extract_video_frames: metadata from frame extraction; the keys
          ``frames_dir`` and ``num_frames`` are read here.

    Outputs:
        - The returned list is handed to the ``json_io_manager`` IO manager
          (presumably written as data/detect_ball_positions.json; confirm
          against the IO manager configuration).
        - Annotated preview images, written by ``_save_detection_preview``.

    Returns:
        List of dicts with:
        - frame: frame number
        - x: pixel x coordinate (or None if not detected)
        - y: pixel y coordinate (or None if not detected)
        - confidence: detection confidence (0.0 if not detected)
        - diameter_px: mean of bbox width and height in pixels (or None)
        - bbox: box as a list of floats, treated as
          [x_min, y_min, x_max, y_max] (or None)
    """
    # Imported lazily: the module is only resolvable after the sys.path
    # tweak at the top of this file.
    from src.ball_detector import BallDetector

    frames_dir = Path(extract_video_frames['frames_dir'])
    num_frames = extract_video_frames['num_frames']

    context.log.info("Initializing YOLO ball detector...")

    # Initialize detector
    detector = BallDetector(
        model_id="pickleball-moving-ball/5",
        confidence_threshold=0.3,  # Lower threshold to catch more detections
        slice_enabled=False  # Disable slicing for faster Hosted API inference
    )

    context.log.info(f"Processing {num_frames} frames for ball detection...")

    detections = []
    frames_with_ball = 0

    for i in tqdm(range(num_frames), desc="Detecting ball"):
        frame_path = frames_dir / f"frame_{i:04d}.jpg"

        if not frame_path.exists():
            context.log.warning(f"Frame {i} not found: {frame_path}")
            detections.append(_empty_detection(i))
            continue

        # cv2.imread signals failure by returning None instead of raising,
        # so a corrupt or truncated file must be caught here before it
        # reaches the detector.
        frame = cv2.imread(str(frame_path))
        if frame is None:
            context.log.warning(f"Frame {i} could not be read: {frame_path}")
            detections.append(_empty_detection(i))
            continue

        # Detect ball
        results = detector.detect(frame)

        if results:
            # Pick the highest-confidence detection explicitly rather than
            # relying on the detector's ordering. If the detector already
            # sorts by descending confidence this selects the same element
            # as results[0].
            ball = max(results, key=lambda det: det['confidence'])

            # Calculate diameter from bbox (mean of width and height)
            bbox = ball.get('bbox')
            has_bbox = bbox is not None and len(bbox) >= 4
            diameter_px = None
            if has_bbox:
                width = bbox[2] - bbox[0]
                height = bbox[3] - bbox[1]
                diameter_px = (width + height) / 2

            detections.append({
                "frame": i,
                "x": float(ball['center'][0]),
                "y": float(ball['center'][1]),
                "confidence": float(ball['confidence']),
                # A zero-sized box yields no usable diameter, so it is
                # reported as None.
                "diameter_px": float(diameter_px) if diameter_px else None,
                "bbox": [float(b) for b in bbox] if has_bbox else None
            })

            frames_with_ball += 1
        else:
            # No detection
            detections.append(_empty_detection(i))

        # Log progress every 20 frames
        if (i + 1) % 20 == 0:
            detection_rate = frames_with_ball / (i + 1) * 100
            context.log.info(f"Processed {i + 1}/{num_frames} frames. Detection rate: {detection_rate:.1f}%")

    # Guard the summary against an empty extraction (num_frames == 0).
    detection_rate = frames_with_ball / num_frames * 100 if num_frames else 0.0
    context.log.info(f"✓ Ball detected in {frames_with_ball}/{num_frames} frames ({detection_rate:.1f}%)")

    # Save ALL detection images: the limit is the number of records, so no
    # detected frame is dropped by an arbitrary cap.
    _save_detection_preview(context, frames_dir, detections, num_preview=len(detections))

    return detections


def _empty_detection(frame_index: int) -> Dict:
    """Build the record stored for a frame in which no ball was found."""
    return {
        "frame": frame_index,
        "x": None,
        "y": None,
        "confidence": 0.0,
        "diameter_px": None,
        "bbox": None
    }
|
||||
|
||||
|
||||
def _save_detection_preview(
    context: AssetExecutionContext,
    frames_dir: Path,
    detections: List[Dict],
    num_preview: int = 5
) -> None:
    """
    Save preview images showing ball detections.

    For the first ``num_preview`` records whose ``x`` is not None, the
    matching ``frame_NNNN.jpg`` is loaded from ``frames_dir``, annotated
    with a filled red dot at the ball centre, a green bounding box (when a
    bbox is present) and the confidence value, then written to
    ``data/<run_id>/ball_detections/detection_frame_NNNN.jpg`` (a path
    relative to the current working directory).

    Frames that are missing on disk or cannot be decoded are skipped; the
    final log line reports how many previews were actually written.

    Args:
        context: Dagster execution context; supplies ``run_id`` and ``log``.
        frames_dir: Directory containing the extracted frame images.
        detections: Per-frame records with the keys ``frame``, ``x``, ``y``,
            ``confidence`` and ``bbox``.
        num_preview: Maximum number of detected frames to render.
    """
    run_id = context.run_id
    preview_dir = Path(f"data/{run_id}/ball_detections")
    preview_dir.mkdir(parents=True, exist_ok=True)

    # Find first N frames with detections
    detected_frames = [d for d in detections if d['x'] is not None][:num_preview]

    saved = 0
    for detection in detected_frames:
        frame_num = detection['frame']
        frame_path = frames_dir / f"frame_{frame_num:04d}.jpg"

        if not frame_path.exists():
            continue

        # cv2.imread returns None for files it cannot decode; drawing on
        # None would raise, so skip such frames.
        frame = cv2.imread(str(frame_path))
        if frame is None:
            context.log.warning(f"Could not read frame for preview: {frame_path}")
            continue

        # Draw ball
        x, y = int(detection['x']), int(detection['y'])
        cv2.circle(frame, (x, y), 8, (0, 0, 255), -1)

        # Draw bbox if available
        bbox = detection['bbox']
        if bbox:
            cv2.rectangle(
                frame,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[2]), int(bbox[3])),
                (0, 255, 0),
                2
            )

        # Draw confidence
        cv2.putText(
            frame,
            f"Conf: {detection['confidence']:.2f}",
            (x + 15, y - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (255, 255, 255),
            2
        )

        # Save preview; cv2.imwrite reports failure through its return value.
        preview_path = preview_dir / f"detection_frame_{frame_num:04d}.jpg"
        if cv2.imwrite(str(preview_path), frame):
            saved += 1
        else:
            context.log.warning(f"Failed to write preview image: {preview_path}")

    context.log.info(f"Saved {saved} preview images to {preview_dir}")
|
||||
Reference in New Issue
Block a user