Initial commit
This commit is contained in:
378
src/video_processor.py
Normal file
378
src/video_processor.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
Main video processing pipeline
|
||||
Combines ball detection, court calibration, and tracking
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
|
||||
from .ball_detector import BallDetector
|
||||
from .court_calibrator import CourtCalibrator, InteractiveCalibrator
|
||||
from .ball_tracker import BallTracker
|
||||
|
||||
|
||||
class VideoProcessor:
    """
    Main pipeline for processing pickleball videos.

    Combines ball detection, court calibration, and tracking into a single
    workflow: load a video, (optionally) calibrate the court, then extract
    the ball trajectory frame by frame into a JSON-serializable result dict.
    """

    def __init__(
        self,
        model_id: str = "pickleball-detection-1oqlw/1",
        config_path: Optional[str] = None
    ):
        """
        Initialize video processor.

        Args:
            model_id: Roboflow model ID
            config_path: Path to configuration JSON file
        """
        # Load config (built-in defaults merged with the optional user file)
        self.config = self._load_config(config_path)

        # Initialize components
        self.detector = BallDetector(
            model_id=model_id,
            confidence_threshold=self.config['detection']['confidence_threshold'],
            slice_enabled=self.config['detection']['slice_enabled'],
            slice_height=self.config['detection']['slice_height'],
            slice_width=self.config['detection']['slice_width']
        )

        self.calibrator = CourtCalibrator(
            court_width_m=self.config['court']['width_m'],
            court_length_m=self.config['court']['length_m']
        )

        self.tracker = BallTracker(
            buffer_size=self.config['tracking']['buffer_size'],
            max_distance_threshold=self.config['tracking']['max_distance_threshold']
        )

        # Video state — populated by load_video()
        self.video_path = None
        self.cap = None
        self.fps = None
        self.total_frames = None

    def _load_config(self, config_path: Optional[str]) -> Dict:
        """
        Load configuration from a JSON file, merged over built-in defaults.

        The merge is performed per section: a user file that supplies only
        some of the keys inside 'court'/'detection'/'tracking' still
        inherits the remaining defaults. (A top-level-only merge would drop
        them and cause KeyErrors later in __init__.) Extra top-level keys in
        the user file are preserved unchanged.

        Args:
            config_path: Path to a JSON config file, or None for defaults only.

        Returns:
            Fully-populated configuration dictionary.
        """
        default_config = {
            'court': {
                'width_m': 6.1,
                'length_m': 13.4
            },
            'detection': {
                'confidence_threshold': 0.4,
                'iou_threshold': 0.5,
                'slice_enabled': True,
                'slice_height': 320,
                'slice_width': 320
            },
            'tracking': {
                'buffer_size': 10,
                'max_distance_threshold': 100
            }
        }

        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                user_config = json.load(f)

            # Section-wise merge: user values override defaults per sub-key
            config = {}
            for section, defaults in default_config.items():
                merged = dict(defaults)
                merged.update(user_config.get(section, {}))
                config[section] = merged
            # Keep any additional top-level keys the user defined
            for key, value in user_config.items():
                if key not in config:
                    config[key] = value
            return config

        return default_config

    def load_video(self, video_path: str) -> bool:
        """
        Load video file.

        Args:
            video_path: Path to video file

        Returns:
            True if successful, False otherwise
        """
        self.video_path = video_path
        self.cap = cv2.VideoCapture(video_path)

        if not self.cap.isOpened():
            print(f"Error: Could not open video {video_path}")
            return False

        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

        print(f"✓ Loaded video: {video_path}")
        print(f"  FPS: {self.fps}")
        print(f"  Total frames: {self.total_frames}")
        # Some containers report an FPS of 0 — avoid ZeroDivisionError
        if self.fps:
            print(f"  Duration: {self.total_frames / self.fps:.2f} seconds")

        return True

    def calibrate_court(
        self,
        corner_points: Optional[List[Tuple[float, float]]] = None,
        interactive: bool = False
    ) -> bool:
        """
        Calibrate court for coordinate transformation.

        Args:
            corner_points: Manual corner points [TL, TR, BR, BL], or None for interactive
            interactive: If True, use interactive calibration tool

        Returns:
            True if calibration successful
        """
        if corner_points is not None:
            return self.calibrator.calibrate_manual(corner_points)

        if interactive:
            # Interactive mode needs a frame, hence a loaded video
            if self.cap is None:
                print("Error: Could not read first frame")
                return False

            # Get first frame
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ret, frame = self.cap.read()
            if not ret:
                print("Error: Could not read first frame")
                return False

            # Interactive calibration
            calibrator = InteractiveCalibrator()
            points = calibrator.calibrate_interactive(frame)

            if points is None:
                return False

            return self.calibrator.calibrate_manual(points)

        print("Warning: No calibration method specified")
        print("Processing will continue without coordinate transformation")
        return False

    def process_video(
        self,
        output_path: Optional[str] = None,
        save_visualization: bool = False,
        visualization_path: Optional[str] = None,
        start_frame: int = 0,
        end_frame: Optional[int] = None
    ) -> Dict:
        """
        Process video and extract ball trajectory.

        Args:
            output_path: Path to save JSON results
            save_visualization: If True, save video with annotations
            visualization_path: Path to save visualization video
            start_frame: Frame to start processing from
            end_frame: Frame to end processing at (None = end of video)

        Returns:
            Dictionary with processing results

        Raises:
            ValueError: If no video has been loaded via load_video().
        """
        if self.cap is None:
            raise ValueError("No video loaded. Call load_video() first")

        # Reset tracker
        self.tracker.reset()

        # Set frame range (clamped so a too-large end_frame is harmless)
        if end_frame is None:
            end_frame = self.total_frames
        else:
            end_frame = min(end_frame, self.total_frames)

        # Setup visualization writer if needed
        video_writer = None
        if save_visualization:
            if visualization_path is None:
                visualization_path = str(Path(self.video_path).stem) + "_tracked.mp4"

            frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(
                visualization_path,
                fourcc,
                self.fps,
                (frame_width, frame_height)
            )

        # Processing results
        results = {
            'video_path': self.video_path,
            'fps': self.fps,
            'total_frames': end_frame - start_frame,
            'duration_sec': (end_frame - start_frame) / self.fps,
            'court': {
                'width_m': self.config['court']['width_m'],
                'length_m': self.config['court']['length_m'],
                'calibrated': self.calibrator.is_calibrated()
            },
            'frames': []
        }

        # Process frames
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        print(f"\nProcessing frames {start_frame} to {end_frame}...")
        start_time = time.time()

        # try/finally guarantees the writer is released even if a
        # detector/tracker call raises mid-loop (avoids a corrupt/locked file)
        try:
            for frame_num in tqdm(range(start_frame, end_frame)):
                ret, frame = self.cap.read()
                if not ret:
                    break

                # Detect ball
                detections = self.detector.detect(frame)

                # Track ball
                ball = self.tracker.update(detections, frame_num)

                # Transform to real-world coordinates if calibrated
                frame_data = {
                    'frame_number': frame_num,
                    'timestamp': frame_num / self.fps,
                    'ball': None
                }

                if ball is not None:
                    pixel_coords = ball['center']
                    real_coords = None

                    if self.calibrator.is_calibrated():
                        real_coords = self.calibrator.pixel_to_real(pixel_coords)

                    # NOTE: explicit "is not None" — a (0, 0) coordinate is
                    # valid, and an array-like return has ambiguous truthiness
                    frame_data['ball'] = {
                        'detected': True,
                        'pixel_coords': {
                            'x': float(pixel_coords[0]),
                            'y': float(pixel_coords[1])
                        },
                        'real_coords_m': {
                            'x': float(real_coords[0]),
                            'y': float(real_coords[1])
                        } if real_coords is not None else None,
                        'confidence': float(ball['confidence']),
                        'interpolated': ball.get('interpolated', False)
                    }

                results['frames'].append(frame_data)

                # Draw visualization if needed
                if save_visualization:
                    vis_frame = self._draw_visualization(frame, ball, frame_num)
                    video_writer.write(vis_frame)
        finally:
            # Cleanup
            if video_writer is not None:
                video_writer.release()

        processing_time = time.time() - start_time
        results['processing_time_sec'] = processing_time
        results['fps_processing'] = (end_frame - start_frame) / processing_time

        print(f"\n✓ Processing complete!")
        print(f"  Time: {processing_time:.2f} seconds")
        print(f"  Speed: {results['fps_processing']:.2f} FPS")

        # Save results
        if output_path:
            self.save_results(results, output_path)

        return results

    def _draw_visualization(
        self,
        frame: np.ndarray,
        ball: Optional[Dict],
        frame_num: int
    ) -> np.ndarray:
        """
        Draw visualization on frame.

        Args:
            frame: Input frame
            ball: Ball detection
            frame_num: Current frame number

        Returns:
            Frame with visualization
        """
        vis = frame.copy()

        # Draw court overlay if calibrated
        if self.calibrator.is_calibrated():
            vis = self.calibrator.draw_court_overlay(vis)

        # Draw ball trajectory
        vis = self.tracker.draw_trajectory(vis)

        # Draw current ball
        if ball is not None:
            center = ball['center']
            # Green = real detection, blue (BGR) = interpolated position
            color = (0, 255, 0) if not ball.get('interpolated', False) else (255, 0, 0)

            # Draw bounding box; only estimate one when none was detected
            bbox = ball.get('bbox')
            if bbox is None:
                bbox = self.tracker._estimate_bbox(center)
            cv2.rectangle(
                vis,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[2]), int(bbox[3])),
                color,
                2
            )

            # Draw center point
            cv2.circle(vis, (int(center[0]), int(center[1])), 5, color, -1)

            # Draw confidence
            if ball['confidence'] > 0:
                cv2.putText(
                    vis,
                    f"{ball['confidence']:.2f}",
                    (int(center[0]) + 10, int(center[1]) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    2
                )

        # Draw frame number
        cv2.putText(
            vis,
            f"Frame: {frame_num}",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 255, 255),
            2
        )

        return vis

    def save_results(self, results: Dict, output_path: str):
        """
        Save processing results to JSON file.

        Args:
            results: Results dictionary
            output_path: Path to save JSON
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            json.dump(results, f, indent=2)

        print(f"✓ Results saved to {output_path}")

    def close(self):
        """Release video resources (safe to call more than once)."""
        if self.cap is not None:
            self.cap.release()
            # Drop the released handle so repeated close() calls are no-ops
            self.cap = None
|
||||
Reference in New Issue
Block a user