Files
pickle_vision/dagster_project/assets/ball_detection.py
Ruslan Bakiev 549fd1da9d Initial commit
2026-03-06 09:43:52 +07:00

182 lines
5.5 KiB
Python

"""Asset 3: Detect ball positions using YOLO"""
import sys
import cv2
from pathlib import Path
from typing import Dict, List, Optional
from dagster import asset, AssetExecutionContext
from tqdm import tqdm
# Put the directory three levels above this file on sys.path so that the
# `src` package (home of the existing ball detector, imported as
# `src.ball_detector` inside the asset) can be resolved.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
def _empty_detection(frame_index: int) -> Dict:
    """Build the placeholder record for a frame without a usable detection."""
    return {
        "frame": frame_index,
        "x": None,
        "y": None,
        "confidence": 0.0,
        "diameter_px": None,
        "bbox": None
    }


def _detection_record(frame_index: int, ball: Dict) -> Dict:
    """Convert one detector result into a JSON-serializable record."""
    bbox = ball.get('bbox')
    # Explicit None/length check instead of truthiness: array-like boxes have
    # no unambiguous truth value, and the corner arithmetic needs 4 entries.
    has_bbox = bbox is not None and len(bbox) >= 4

    diameter_px = None
    if has_bbox:
        # bbox is treated as [x_min, y_min, x_max, y_max]; the diameter is the
        # mean of the box width and height.
        width = bbox[2] - bbox[0]
        height = bbox[3] - bbox[1]
        mean_side = (width + height) / 2
        # A zero-size box carries no size information; report it as missing.
        diameter_px = float(mean_side) if mean_side else None

    return {
        "frame": frame_index,
        "x": float(ball['center'][0]),
        "y": float(ball['center'][1]),
        "confidence": float(ball['confidence']),
        "diameter_px": diameter_px,
        "bbox": [float(b) for b in bbox] if has_bbox else None
    }


@asset(
    io_manager_key="json_io_manager",
    compute_kind="yolo",
    description="Detect ball positions on all frames using YOLO"
)
def detect_ball_positions(
    context: AssetExecutionContext,
    extract_video_frames: Dict
) -> List[Dict]:
    """
    Detect ball positions on all extracted frames

    Inputs:
        - extract_video_frames: metadata from frame extraction; the keys
          'frames_dir' and 'num_frames' are read
        - <frames_dir>/frame_NNNN.jpg: the extracted frames, numbered from 0

    Outputs:
        - data/detect_ball_positions.json

    Returns:
        List of dicts, one per frame index in range(num_frames), with:
        - frame: frame number
        - x: pixel x coordinate (or None if not detected)
        - y: pixel y coordinate (or None if not detected)
        - confidence: detection confidence (0.0 if not detected)
        - diameter_px: estimated ball diameter in pixels (or None)
        - bbox: four floats [x_min, y_min, x_max, y_max] (or None)

    Frames that are missing on disk or cannot be decoded are logged as
    warnings and recorded as "not detected" instead of aborting the run.
    """
    # Resolved through the sys.path entry added at module import time.
    from src.ball_detector import BallDetector

    frames_dir = Path(extract_video_frames['frames_dir'])
    num_frames = extract_video_frames['num_frames']

    context.log.info("Initializing YOLO ball detector...")

    # Initialize detector
    detector = BallDetector(
        model_id="pickleball-moving-ball/5",
        confidence_threshold=0.3,  # Lower threshold to catch more detections
        slice_enabled=False  # Disable slicing for faster Hosted API inference
    )

    context.log.info(f"Processing {num_frames} frames for ball detection...")

    detections = []
    frames_with_ball = 0

    for i in tqdm(range(num_frames), desc="Detecting ball"):
        frame_path = frames_dir / f"frame_{i:04d}.jpg"

        if not frame_path.exists():
            context.log.warning(f"Frame {i} not found: {frame_path}")
            frame = None
        else:
            # cv2.imread signals an unreadable/corrupt file by returning None
            # rather than raising, so the result has to be checked.
            frame = cv2.imread(str(frame_path))
            if frame is None:
                context.log.warning(f"Frame {i} could not be read: {frame_path}")

        results = detector.detect(frame) if frame is not None else None

        if results:
            # Pick the highest-confidence detection explicitly instead of
            # relying on the detector's output order (ties keep the first).
            best = max(results, key=lambda det: det['confidence'])
            detections.append(_detection_record(i, best))
            frames_with_ball += 1
        else:
            # No detection (or no usable frame)
            detections.append(_empty_detection(i))

        # Log progress every 20 frames, including frames that were skipped
        if (i + 1) % 20 == 0:
            detection_rate = frames_with_ball / (i + 1) * 100
            context.log.info(f"Processed {i + 1}/{num_frames} frames. Detection rate: {detection_rate:.1f}%")

    # Guard the summary against an empty extraction (num_frames == 0).
    detection_rate = frames_with_ball / num_frames * 100 if num_frames > 0 else 0.0
    context.log.info(f"✓ Ball detected in {frames_with_ball}/{num_frames} frames ({detection_rate:.1f}%)")

    # Save ALL detection images: cap at the number of records rather than a
    # fixed constant so long videos are not silently truncated.
    _save_detection_preview(context, frames_dir, detections, num_preview=len(detections))

    return detections
def _save_detection_preview(
    context: AssetExecutionContext,
    frames_dir: Path,
    detections: List[Dict],
    num_preview: int = 5
):
    """Save preview images showing ball detections

    Annotated copies of the first ``num_preview`` frames that have a
    detection are written to ``data/<run_id>/ball_detections`` as
    ``detection_frame_NNNN.jpg``. Each copy shows a filled red dot at the
    ball position, a green rectangle for the bounding box (when present)
    and the confidence as white text.

    Args:
        context: Dagster execution context; supplies ``run_id`` and the logger.
        frames_dir: Directory holding the source ``frame_NNNN.jpg`` images.
        detections: Detection records; entries whose ``x`` is None are ignored.
        num_preview: Maximum number of previews to write (negative values are
            treated as 0).

    Frames that are missing, unreadable, or cannot be written are skipped;
    the final log line reports how many previews were actually saved.
    """
    preview_dir = Path(f"data/{context.run_id}/ball_detections")
    preview_dir.mkdir(parents=True, exist_ok=True)

    # Find first N frames with detections. Clamp at 0 so a negative limit
    # does not turn into "all but the last N" via slice semantics.
    limit = max(num_preview, 0)
    detected_frames = [d for d in detections if d['x'] is not None][:limit]

    saved = 0
    for detection in detected_frames:
        frame_num = detection['frame']
        frame_path = frames_dir / f"frame_{frame_num:04d}.jpg"
        if not frame_path.exists():
            continue

        # cv2.imread returns None for an undecodable file; drawing on None
        # would raise, so skip such frames with a warning.
        frame = cv2.imread(str(frame_path))
        if frame is None:
            context.log.warning(f"Could not read frame for preview: {frame_path}")
            continue

        # Draw ball (filled circle; colors are BGR)
        x, y = int(detection['x']), int(detection['y'])
        cv2.circle(frame, (x, y), 8, (0, 0, 255), -1)

        # Draw bbox if available
        bbox = detection.get('bbox')
        if bbox:
            cv2.rectangle(
                frame,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[2]), int(bbox[3])),
                (0, 255, 0),
                2
            )

        # Draw confidence
        cv2.putText(
            frame,
            f"Conf: {detection['confidence']:.2f}",
            (x + 15, y - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (255, 255, 255),
            2
        )

        # Save preview; cv2.imwrite reports failure through its return value.
        preview_path = preview_dir / f"detection_frame_{frame_num:04d}.jpg"
        if cv2.imwrite(str(preview_path), frame):
            saved += 1
        else:
            context.log.warning(f"Failed to write preview image: {preview_path}")

    context.log.info(f"Saved {saved} preview images to {preview_dir}")