"""
Main video processing pipeline

Combines ball detection, court calibration, and tracking
"""
|
|
import cv2
|
|
import numpy as np
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple
|
|
from tqdm import tqdm
|
|
import time
|
|
|
|
from .ball_detector import BallDetector
|
|
from .court_calibrator import CourtCalibrator, InteractiveCalibrator
|
|
from .ball_tracker import BallTracker
|
|
|
|
|
|
class VideoProcessor:
|
|
"""
|
|
Main pipeline for processing pickleball videos
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model_id: str = "pickleball-detection-1oqlw/1",
|
|
config_path: Optional[str] = None
|
|
):
|
|
"""
|
|
Initialize video processor
|
|
|
|
Args:
|
|
model_id: Roboflow model ID
|
|
config_path: Path to configuration JSON file
|
|
"""
|
|
# Load config
|
|
self.config = self._load_config(config_path)
|
|
|
|
# Initialize components
|
|
self.detector = BallDetector(
|
|
model_id=model_id,
|
|
confidence_threshold=self.config['detection']['confidence_threshold'],
|
|
slice_enabled=self.config['detection']['slice_enabled'],
|
|
slice_height=self.config['detection']['slice_height'],
|
|
slice_width=self.config['detection']['slice_width']
|
|
)
|
|
|
|
self.calibrator = CourtCalibrator(
|
|
court_width_m=self.config['court']['width_m'],
|
|
court_length_m=self.config['court']['length_m']
|
|
)
|
|
|
|
self.tracker = BallTracker(
|
|
buffer_size=self.config['tracking']['buffer_size'],
|
|
max_distance_threshold=self.config['tracking']['max_distance_threshold']
|
|
)
|
|
|
|
self.video_path = None
|
|
self.cap = None
|
|
self.fps = None
|
|
self.total_frames = None
|
|
|
|
def _load_config(self, config_path: Optional[str]) -> Dict:
|
|
"""Load configuration from JSON file"""
|
|
default_config = {
|
|
'court': {
|
|
'width_m': 6.1,
|
|
'length_m': 13.4
|
|
},
|
|
'detection': {
|
|
'confidence_threshold': 0.4,
|
|
'iou_threshold': 0.5,
|
|
'slice_enabled': True,
|
|
'slice_height': 320,
|
|
'slice_width': 320
|
|
},
|
|
'tracking': {
|
|
'buffer_size': 10,
|
|
'max_distance_threshold': 100
|
|
}
|
|
}
|
|
|
|
if config_path and Path(config_path).exists():
|
|
with open(config_path, 'r') as f:
|
|
config = json.load(f)
|
|
# Merge with defaults
|
|
for key in default_config:
|
|
if key not in config:
|
|
config[key] = default_config[key]
|
|
return config
|
|
|
|
return default_config
|
|
|
|
def load_video(self, video_path: str) -> bool:
|
|
"""
|
|
Load video file
|
|
|
|
Args:
|
|
video_path: Path to video file
|
|
|
|
Returns:
|
|
True if successful, False otherwise
|
|
"""
|
|
self.video_path = video_path
|
|
self.cap = cv2.VideoCapture(video_path)
|
|
|
|
if not self.cap.isOpened():
|
|
print(f"Error: Could not open video {video_path}")
|
|
return False
|
|
|
|
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
|
|
self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
|
|
print(f"✓ Loaded video: {video_path}")
|
|
print(f" FPS: {self.fps}")
|
|
print(f" Total frames: {self.total_frames}")
|
|
print(f" Duration: {self.total_frames / self.fps:.2f} seconds")
|
|
|
|
return True
|
|
|
|
def calibrate_court(
|
|
self,
|
|
corner_points: Optional[List[Tuple[float, float]]] = None,
|
|
interactive: bool = False
|
|
) -> bool:
|
|
"""
|
|
Calibrate court for coordinate transformation
|
|
|
|
Args:
|
|
corner_points: Manual corner points [TL, TR, BR, BL], or None for interactive
|
|
interactive: If True, use interactive calibration tool
|
|
|
|
Returns:
|
|
True if calibration successful
|
|
"""
|
|
if corner_points is not None:
|
|
return self.calibrator.calibrate_manual(corner_points)
|
|
|
|
if interactive:
|
|
# Get first frame
|
|
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
|
ret, frame = self.cap.read()
|
|
if not ret:
|
|
print("Error: Could not read first frame")
|
|
return False
|
|
|
|
# Interactive calibration
|
|
calibrator = InteractiveCalibrator()
|
|
points = calibrator.calibrate_interactive(frame)
|
|
|
|
if points is None:
|
|
return False
|
|
|
|
return self.calibrator.calibrate_manual(points)
|
|
|
|
print("Warning: No calibration method specified")
|
|
print("Processing will continue without coordinate transformation")
|
|
return False
|
|
|
|
def process_video(
|
|
self,
|
|
output_path: Optional[str] = None,
|
|
save_visualization: bool = False,
|
|
visualization_path: Optional[str] = None,
|
|
start_frame: int = 0,
|
|
end_frame: Optional[int] = None
|
|
) -> Dict:
|
|
"""
|
|
Process video and extract ball trajectory
|
|
|
|
Args:
|
|
output_path: Path to save JSON results
|
|
save_visualization: If True, save video with annotations
|
|
visualization_path: Path to save visualization video
|
|
start_frame: Frame to start processing from
|
|
end_frame: Frame to end processing at (None = end of video)
|
|
|
|
Returns:
|
|
Dictionary with processing results
|
|
"""
|
|
if self.cap is None:
|
|
raise ValueError("No video loaded. Call load_video() first")
|
|
|
|
# Reset tracker
|
|
self.tracker.reset()
|
|
|
|
# Set frame range
|
|
if end_frame is None:
|
|
end_frame = self.total_frames
|
|
|
|
# Setup visualization writer if needed
|
|
video_writer = None
|
|
if save_visualization:
|
|
if visualization_path is None:
|
|
visualization_path = str(Path(self.video_path).stem) + "_tracked.mp4"
|
|
|
|
frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
|
video_writer = cv2.VideoWriter(
|
|
visualization_path,
|
|
fourcc,
|
|
self.fps,
|
|
(frame_width, frame_height)
|
|
)
|
|
|
|
# Processing results
|
|
results = {
|
|
'video_path': self.video_path,
|
|
'fps': self.fps,
|
|
'total_frames': end_frame - start_frame,
|
|
'duration_sec': (end_frame - start_frame) / self.fps,
|
|
'court': {
|
|
'width_m': self.config['court']['width_m'],
|
|
'length_m': self.config['court']['length_m'],
|
|
'calibrated': self.calibrator.is_calibrated()
|
|
},
|
|
'frames': []
|
|
}
|
|
|
|
# Process frames
|
|
self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
|
|
|
|
print(f"\nProcessing frames {start_frame} to {end_frame}...")
|
|
start_time = time.time()
|
|
|
|
for frame_num in tqdm(range(start_frame, end_frame)):
|
|
ret, frame = self.cap.read()
|
|
if not ret:
|
|
break
|
|
|
|
# Detect ball
|
|
detections = self.detector.detect(frame)
|
|
|
|
# Track ball
|
|
ball = self.tracker.update(detections, frame_num)
|
|
|
|
# Transform to real-world coordinates if calibrated
|
|
frame_data = {
|
|
'frame_number': frame_num,
|
|
'timestamp': frame_num / self.fps,
|
|
'ball': None
|
|
}
|
|
|
|
if ball is not None:
|
|
pixel_coords = ball['center']
|
|
real_coords = None
|
|
|
|
if self.calibrator.is_calibrated():
|
|
real_coords = self.calibrator.pixel_to_real(pixel_coords)
|
|
|
|
frame_data['ball'] = {
|
|
'detected': True,
|
|
'pixel_coords': {
|
|
'x': float(pixel_coords[0]),
|
|
'y': float(pixel_coords[1])
|
|
},
|
|
'real_coords_m': {
|
|
'x': float(real_coords[0]) if real_coords else None,
|
|
'y': float(real_coords[1]) if real_coords else None
|
|
} if real_coords else None,
|
|
'confidence': float(ball['confidence']),
|
|
'interpolated': ball.get('interpolated', False)
|
|
}
|
|
|
|
results['frames'].append(frame_data)
|
|
|
|
# Draw visualization if needed
|
|
if save_visualization:
|
|
vis_frame = self._draw_visualization(frame, ball, frame_num)
|
|
video_writer.write(vis_frame)
|
|
|
|
# Cleanup
|
|
if video_writer is not None:
|
|
video_writer.release()
|
|
|
|
processing_time = time.time() - start_time
|
|
results['processing_time_sec'] = processing_time
|
|
results['fps_processing'] = (end_frame - start_frame) / processing_time
|
|
|
|
print(f"\n✓ Processing complete!")
|
|
print(f" Time: {processing_time:.2f} seconds")
|
|
print(f" Speed: {results['fps_processing']:.2f} FPS")
|
|
|
|
# Save results
|
|
if output_path:
|
|
self.save_results(results, output_path)
|
|
|
|
return results
|
|
|
|
def _draw_visualization(
|
|
self,
|
|
frame: np.ndarray,
|
|
ball: Optional[Dict],
|
|
frame_num: int
|
|
) -> np.ndarray:
|
|
"""
|
|
Draw visualization on frame
|
|
|
|
Args:
|
|
frame: Input frame
|
|
ball: Ball detection
|
|
frame_num: Current frame number
|
|
|
|
Returns:
|
|
Frame with visualization
|
|
"""
|
|
vis = frame.copy()
|
|
|
|
# Draw court overlay if calibrated
|
|
if self.calibrator.is_calibrated():
|
|
vis = self.calibrator.draw_court_overlay(vis)
|
|
|
|
# Draw ball trajectory
|
|
vis = self.tracker.draw_trajectory(vis)
|
|
|
|
# Draw current ball
|
|
if ball is not None:
|
|
center = ball['center']
|
|
color = (0, 255, 0) if not ball.get('interpolated', False) else (255, 0, 0)
|
|
|
|
# Draw bounding box
|
|
bbox = ball.get('bbox', self.tracker._estimate_bbox(center))
|
|
cv2.rectangle(
|
|
vis,
|
|
(int(bbox[0]), int(bbox[1])),
|
|
(int(bbox[2]), int(bbox[3])),
|
|
color,
|
|
2
|
|
)
|
|
|
|
# Draw center point
|
|
cv2.circle(vis, (int(center[0]), int(center[1])), 5, color, -1)
|
|
|
|
# Draw confidence
|
|
if ball['confidence'] > 0:
|
|
cv2.putText(
|
|
vis,
|
|
f"{ball['confidence']:.2f}",
|
|
(int(center[0]) + 10, int(center[1]) - 10),
|
|
cv2.FONT_HERSHEY_SIMPLEX,
|
|
0.5,
|
|
color,
|
|
2
|
|
)
|
|
|
|
# Draw frame number
|
|
cv2.putText(
|
|
vis,
|
|
f"Frame: {frame_num}",
|
|
(10, 30),
|
|
cv2.FONT_HERSHEY_SIMPLEX,
|
|
1,
|
|
(255, 255, 255),
|
|
2
|
|
)
|
|
|
|
return vis
|
|
|
|
def save_results(self, results: Dict, output_path: str):
|
|
"""
|
|
Save processing results to JSON file
|
|
|
|
Args:
|
|
results: Results dictionary
|
|
output_path: Path to save JSON
|
|
"""
|
|
output_path = Path(output_path)
|
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
with open(output_path, 'w') as f:
|
|
json.dump(results, f, indent=2)
|
|
|
|
print(f"✓ Results saved to {output_path}")
|
|
|
|
def close(self):
|
|
"""Release video resources"""
|
|
if self.cap is not None:
|
|
self.cap.release()
|