Initial commit

This commit is contained in:
Ruslan Bakiev
2026-03-06 09:43:52 +07:00
commit 549fd1da9d
250 changed files with 9114 additions and 0 deletions

17
src/__init__.py Normal file
View File

@@ -0,0 +1,17 @@
"""
Pickle - Pickleball Ball Tracking System
"""
from .ball_detector import BallDetector
from .court_calibrator import CourtCalibrator, InteractiveCalibrator
from .ball_tracker import BallTracker, MultiObjectTracker
from .video_processor import VideoProcessor
__all__ = [
'BallDetector',
'CourtCalibrator',
'InteractiveCalibrator',
'BallTracker',
'MultiObjectTracker',
'VideoProcessor'
]

263
src/ball_detector.py Normal file
View File

@@ -0,0 +1,263 @@
"""
Ball detector module using Roboflow Hosted Inference and YOLO v8
"""
import os
import numpy as np
from typing import List, Tuple, Optional, Dict
from inference_sdk import InferenceHTTPClient
import cv2
class BallDetector:
    """
    Detects pickleball balls in video frames using Roboflow pre-trained models.

    Implements an InferenceSlicer-style tiling technique for better small
    object detection; falls back to a local YOLO v8 base model (COCO
    "sports ball" class) when the hosted API is unavailable.
    """

    def __init__(
        self,
        model_id: str = "pickleball-detection-1oqlw/1",
        confidence_threshold: float = 0.4,
        iou_threshold: float = 0.5,
        slice_enabled: bool = True,
        slice_height: int = 320,
        slice_width: int = 320,
        overlap_ratio: float = 0.2
    ):
        """
        Initialize the ball detector.

        Args:
            model_id: Roboflow model ID (format: workspace/project/version)
            confidence_threshold: Minimum confidence for detections
            iou_threshold: IoU threshold for NMS
            slice_enabled: Enable frame slicing for better small object detection
            slice_height: Height of each slice in pixels
            slice_width: Width of each slice in pixels
            overlap_ratio: Overlap ratio between adjacent slices (0..1)
        """
        self.model_id = model_id
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.slice_enabled = slice_enabled
        self.slice_height = slice_height
        self.slice_width = slice_width
        self.overlap_ratio = overlap_ratio
        # Initialize the Roboflow Hosted Inference client, or fall back to a
        # local YOLO v8 model when no API key is configured / client fails.
        api_key = os.getenv("ROBOFLOW_API_KEY")
        if not api_key:
            print("✗ ROBOFLOW_API_KEY not set, using YOLO v8 fallback")
            from ultralytics import YOLO
            self.model = YOLO('yolov8n.pt')
            self.use_fallback = True
        else:
            try:
                self.client = InferenceHTTPClient(
                    api_url="https://serverless.roboflow.com",
                    api_key=api_key
                )
                print(f"✓ Initialized Roboflow Hosted Inference: {model_id}")
                self.use_fallback = False
            except Exception as e:
                print(f"✗ Failed to initialize Roboflow client: {e}")
                print("Falling back to YOLO v8 base model for sports ball detection")
                from ultralytics import YOLO
                self.model = YOLO('yolov8n.pt')
                self.use_fallback = True

    def detect(self, frame: np.ndarray) -> List[Dict]:
        """
        Detect balls in a single frame.

        Args:
            frame: Input frame (numpy array in BGR format)

        Returns:
            List of detections, each of the form:
            {'bbox': [x1, y1, x2, y2], 'confidence': float, 'center': [cx, cy]}
        """
        if self.use_fallback:
            return self._detect_with_yolo(frame)
        if self.slice_enabled:
            return self._detect_with_slicing(frame)
        return self._detect_single(frame)

    def _infer_image(self, image: np.ndarray) -> List[Dict]:
        """
        Run hosted inference on one image and return parsed detections.

        The hosted API expects a file path (or base64), so the image is
        written to a temporary JPEG that is always deleted afterwards.
        """
        import tempfile
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
            cv2.imwrite(tmp.name, image)
            tmp_path = tmp.name
        try:
            results = self.client.infer(tmp_path, model_id=self.model_id)
            return self._parse_results(results)
        finally:
            os.unlink(tmp_path)

    def _detect_single(self, frame: np.ndarray) -> List[Dict]:
        """Detect on the full frame without slicing."""
        try:
            return self._infer_image(frame)
        except Exception as e:
            print(f"Detection error: {e}")
            return []

    def _detect_with_slicing(self, frame: np.ndarray) -> List[Dict]:
        """
        Detect using the InferenceSlicer technique.

        The frame is divided into overlapping tiles so small objects occupy
        a larger fraction of each model input; duplicate hits produced by
        the overlaps are then removed with NMS.
        """
        height, width = frame.shape[:2]
        detections = []
        # Stride leaves `overlap_ratio` of each tile shared with its neighbor
        stride_h = int(self.slice_height * (1 - self.overlap_ratio))
        stride_w = int(self.slice_width * (1 - self.overlap_ratio))
        for y in range(0, height, stride_h):
            for x in range(0, width, stride_w):
                # Extract tile (clipped at the frame border)
                y_end = min(y + self.slice_height, height)
                x_end = min(x + self.slice_width, width)
                slice_img = frame[y:y_end, x:x_end]
                try:
                    slice_detections = self._infer_image(slice_img)
                except Exception:
                    # Best-effort: one failed tile must not abort the frame
                    continue
                for det in slice_detections:
                    # Shift tile-local coordinates back into the full frame
                    det['bbox'][0] += x  # x1
                    det['bbox'][1] += y  # y1
                    det['bbox'][2] += x  # x2
                    det['bbox'][3] += y  # y2
                    det['center'][0] += x  # cx
                    det['center'][1] += y  # cy
                    detections.append(det)
        # Remove duplicate detections from overlapping tiles
        return self._apply_nms(detections)

    def _detect_with_yolo(self, frame: np.ndarray) -> List[Dict]:
        """Fallback detection using ultralytics YOLO (COCO class 32 = sports ball)."""
        results = self.model(frame, conf=self.confidence_threshold, verbose=False)
        detections = []
        for result in results:
            for box in result.boxes:
                # Keep only the COCO "sports ball" class
                if int(box.cls[0]) != 32:
                    continue
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                conf = float(box.conf[0])
                cx = (x1 + x2) / 2
                cy = (y1 + y2) / 2
                detections.append({
                    'bbox': [float(x1), float(y1), float(x2), float(y2)],
                    'confidence': conf,
                    'center': [float(cx), float(cy)]
                })
        return detections

    def _parse_results(self, results) -> List[Dict]:
        """
        Parse Roboflow Hosted API results.

        The API returns {'predictions': [...]} where each prediction carries
        a center-based 'x'/'y' plus 'width'/'height'; these are converted to
        corner-based [x1, y1, x2, y2]. Predictions whose class name does not
        contain "ball" are dropped.
        """
        detections = []
        if not results:
            return detections
        predictions = results.get('predictions', []) if isinstance(results, dict) else []
        for pred in predictions:
            if not isinstance(pred, dict):
                continue
            # Extract center-format bbox, skipping malformed predictions
            if 'x' in pred and 'y' in pred and 'width' in pred and 'height' in pred:
                cx = pred['x']
                cy = pred['y']
                w = pred['width']
                h = pred['height']
                x1 = cx - w / 2
                y1 = cy - h / 2
                x2 = cx + w / 2
                y2 = cy + h / 2
            else:
                continue
            # Filter for ball class
            class_name = pred.get('class', '')
            if 'ball' not in class_name.lower():
                continue
            conf = pred.get('confidence', 0.0)
            detections.append({
                'bbox': [x1, y1, x2, y2],
                'confidence': conf,
                'center': [cx, cy]
            })
        return detections

    def _apply_nms(self, detections: List[Dict]) -> List[Dict]:
        """
        Apply Non-Maximum Suppression to remove duplicate detections.

        Bug fix: cv2.dnn.NMSBoxes expects boxes as [x, y, w, h]; the
        previous code passed [x1, y1, x2, y2], so the IoU overlap test ran
        on wrong rectangles and duplicates from overlapping tiles survived.
        """
        if len(detections) == 0:
            return []
        # Convert corner-format bboxes to the [x, y, w, h] format NMSBoxes needs
        boxes_xywh = [
            [b[0], b[1], b[2] - b[0], b[3] - b[1]]
            for b in (det['bbox'] for det in detections)
        ]
        scores = [float(det['confidence']) for det in detections]
        indices = cv2.dnn.NMSBoxes(
            boxes_xywh,
            scores,
            self.confidence_threshold,
            self.iou_threshold
        )
        if len(indices) == 0:
            return []
        indices = np.array(indices).flatten()
        return [detections[int(i)] for i in indices]

419
src/ball_tracker.py Normal file
View File

@@ -0,0 +1,419 @@
"""
Ball tracking module with buffer-based filtering and trajectory smoothing
"""
import numpy as np
from typing import List, Dict, Optional, Tuple
from collections import deque
import cv2
class BallTracker:
    """
    Tracks ball across frames using buffer-based filtering
    Handles occlusions and false detections
    """
    def __init__(
        self,
        buffer_size: int = 10,
        max_distance_threshold: int = 100,
        min_confidence: float = 0.3
    ):
        """
        Initialize ball tracker
        Args:
            buffer_size: Number of recent positions to store for filtering
            max_distance_threshold: Maximum pixel distance between frames to consider same ball
            min_confidence: Minimum confidence threshold for detections
        """
        self.buffer_size = buffer_size
        self.max_distance_threshold = max_distance_threshold
        self.min_confidence = min_confidence
        # Buffer for storing recent ball positions (deque drops oldest automatically)
        self.position_buffer = deque(maxlen=buffer_size)
        # Track ball state
        self.current_position = None
        self.lost_frames = 0
        self.max_lost_frames = 10  # Maximum frames to interpolate when ball is lost
        # Trajectory history (one entry per frame that produced a position)
        self.trajectory = []
    def update(self, detections: List[Dict], frame_number: int) -> Optional[Dict]:
        """
        Update tracker with new detections from current frame
        Args:
            detections: List of ball detections from detector
            frame_number: Current frame number
        Returns:
            Best ball detection (filtered), or None if no valid detection
        """
        # Filter low confidence detections
        valid_detections = [
            det for det in detections
            if det['confidence'] >= self.min_confidence
        ]
        if len(valid_detections) == 0:
            # No detections - try to interpolate if recently had detection
            return self._handle_missing_detection(frame_number)
        # Select best detection
        best_detection = self._select_best_detection(valid_detections)
        if best_detection is None:
            return self._handle_missing_detection(frame_number)
        # Update buffer and state
        self.position_buffer.append(best_detection['center'])
        self.current_position = best_detection['center']
        self.lost_frames = 0
        # Add to trajectory
        # NOTE: 'position' and 'pixel_coords' alias the same center object
        self.trajectory.append({
            'frame': frame_number,
            'position': best_detection['center'],
            'pixel_coords': best_detection['center'],
            'real_coords': None,  # Will be filled by processor
            'confidence': best_detection['confidence']
        })
        return best_detection
    def _select_best_detection(self, detections: List[Dict]) -> Optional[Dict]:
        """
        Select the most likely ball detection from multiple candidates
        Args:
            detections: List of valid detections
        Returns:
            Best detection, or None
        """
        if len(detections) == 0:
            return None
        if len(detections) == 1:
            return detections[0]
        # If we have position history, use it to filter
        if len(self.position_buffer) > 0:
            return self._select_by_proximity(detections)
        else:
            # No history - return highest confidence
            return max(detections, key=lambda d: d['confidence'])
    def _select_by_proximity(self, detections: List[Dict]) -> Optional[Dict]:
        """
        Select detection closest to predicted position based on buffer
        Args:
            detections: List of detections
        Returns:
            Detection closest to predicted position
        """
        # Calculate average position from buffer (smooths out detector jitter)
        avg_position = np.mean(self.position_buffer, axis=0)
        # Find detection closest to average
        min_distance = float('inf')
        best_detection = None
        for det in detections:
            distance = np.linalg.norm(
                np.array(det['center']) - avg_position
            )
            # Check if within threshold
            if distance < self.max_distance_threshold and distance < min_distance:
                min_distance = distance
                best_detection = det
        # If no detection within threshold, return highest confidence
        # (deliberate fallback: never drop a frame that has candidates)
        if best_detection is None:
            best_detection = max(detections, key=lambda d: d['confidence'])
        return best_detection
    def _handle_missing_detection(self, frame_number: int) -> Optional[Dict]:
        """
        Handle case when no valid detection in current frame
        Args:
            frame_number: Current frame number
        Returns:
            Interpolated detection if possible, None otherwise
        """
        self.lost_frames += 1
        # If lost for too long, stop interpolating
        if self.lost_frames > self.max_lost_frames:
            self.current_position = None
            return None
        # Try to interpolate based on recent positions
        # NOTE(review): interpolated points are not pushed into position_buffer,
        # so the predicted position stays constant while the ball remains lost
        # (the velocity estimate is frozen at the last two real detections) —
        # confirm this is the intended behavior.
        if len(self.position_buffer) >= 2:
            interpolated = self._interpolate_position()
            # Add interpolated position to trajectory (marked as interpolated)
            self.trajectory.append({
                'frame': frame_number,
                'position': interpolated,
                'pixel_coords': interpolated,
                'real_coords': None,
                'confidence': 0.0,  # Mark as interpolated
                'interpolated': True
            })
            return {
                'center': interpolated,
                'bbox': self._estimate_bbox(interpolated),
                'confidence': 0.0,
                'interpolated': True
            }
        return None
    def _interpolate_position(self) -> Tuple[float, float]:
        """
        Interpolate ball position based on recent trajectory
        Returns:
            Estimated (x, y) position
        """
        # Guard: callers only reach here with >= 2 buffered positions
        if len(self.position_buffer) < 2:
            return self.current_position
        # Simple linear interpolation based on last two positions
        positions = np.array(self.position_buffer)
        # Calculate velocity (displacement over one frame)
        velocity = positions[-1] - positions[-2]
        # Predict next position by extrapolating that displacement
        predicted = positions[-1] + velocity
        return tuple(predicted)
    def _estimate_bbox(self, center: Tuple[float, float], size: int = 20) -> List[float]:
        """
        Estimate bounding box for interpolated position
        Args:
            center: Center position
            size: Estimated ball size in pixels
        Returns:
            [x1, y1, x2, y2] bbox
        """
        cx, cy = center
        half_size = size / 2
        return [cx - half_size, cy - half_size, cx + half_size, cy + half_size]
    def get_trajectory(self) -> List[Dict]:
        """
        Get full trajectory history
        Returns:
            List of trajectory points
        """
        return self.trajectory
    def get_smoothed_trajectory(self, window_size: int = 5) -> List[Dict]:
        """
        Get smoothed trajectory using moving average
        Args:
            window_size: Size of smoothing window
        Returns:
            Smoothed trajectory
        """
        if len(self.trajectory) < window_size:
            return self.trajectory
        smoothed = []
        for i, point in enumerate(self.trajectory):
            # Get window centered on the current point (clipped at the ends)
            start = max(0, i - window_size // 2)
            end = min(len(self.trajectory), i + window_size // 2 + 1)
            window = self.trajectory[start:end]
            # Calculate average position (only non-interpolated points)
            valid_positions = [
                p['position'] for p in window
                if not p.get('interpolated', False)
            ]
            if len(valid_positions) > 0:
                avg_position = np.mean(valid_positions, axis=0)
                smoothed_point = point.copy()
                smoothed_point['smoothed_position'] = tuple(avg_position)
                smoothed.append(smoothed_point)
            else:
                # Whole window was interpolated: keep the raw point untouched
                smoothed.append(point)
        return smoothed
    def reset(self):
        """Reset tracker state"""
        self.position_buffer.clear()
        self.current_position = None
        self.lost_frames = 0
        self.trajectory = []
    def draw_trajectory(self, frame: np.ndarray, max_points: int = 30) -> np.ndarray:
        """
        Draw ball trajectory on frame
        Args:
            frame: Input frame
            max_points: Maximum number of trajectory points to draw
        Returns:
            Frame with trajectory overlay
        """
        if len(self.trajectory) == 0:
            return frame
        overlay = frame.copy()
        # Get recent trajectory points
        recent_trajectory = self.trajectory[-max_points:]
        # Draw trajectory line (yellow polyline over the recent path)
        points = [point['position'] for point in recent_trajectory]
        if len(points) > 1:
            pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2))
            cv2.polylines(overlay, [pts], False, (0, 255, 255), 2)
        # Draw current position (filled red circle)
        if self.current_position is not None:
            cv2.circle(
                overlay,
                (int(self.current_position[0]), int(self.current_position[1])),
                8,
                (0, 0, 255),
                -1
            )
        return overlay
class MultiObjectTracker:
    """
    Extended tracker for tracking multiple objects (ball, players, paddles)
    using simple centroid-based association between consecutive frames.
    """

    def __init__(self, max_disappeared: int = 10, max_distance: int = 100):
        """
        Initialize multi-object tracker.

        Args:
            max_disappeared: Max consecutive frames an object may go unseen
                before it is dropped from tracking
            max_distance: Max centroid distance to associate a detection
                with an existing object
        """
        self.next_object_id = 0
        self.objects = {}      # object_id -> last known centroid
        self.disappeared = {}  # object_id -> consecutive unseen frames
        self.max_disappeared = max_disappeared
        self.max_distance = max_distance

    def register(self, centroid: Tuple[float, float]) -> int:
        """Start tracking a new object and return its assigned id."""
        assigned = self.next_object_id
        self.next_object_id += 1
        self.objects[assigned] = centroid
        self.disappeared[assigned] = 0
        return assigned

    def deregister(self, object_id: int):
        """Forget an object entirely."""
        self.objects.pop(object_id)
        self.disappeared.pop(object_id)

    def update(self, detections: List[Dict]) -> Dict[int, Dict]:
        """
        Associate this frame's detections with tracked objects.

        Args:
            detections: List of detections, each carrying a 'center' key

        Returns:
            Dict mapping object_id to the detection assigned to it
        """
        # An empty frame ages every tracked object; stale ones are dropped.
        if not detections:
            for oid in list(self.disappeared):
                self.disappeared[oid] += 1
                if self.disappeared[oid] > self.max_disappeared:
                    self.deregister(oid)
            return {}

        new_centroids = np.array([d['center'] for d in detections])

        # Nothing tracked yet: every detection becomes a fresh object.
        if not self.objects:
            return {
                self.register(c): det
                for c, det in zip(new_centroids, detections)
            }

        ids = list(self.objects.keys())
        tracked = np.array([self.objects[i] for i in ids])

        # Pairwise distances: rows = tracked objects, cols = new detections.
        dist = np.linalg.norm(tracked[:, np.newaxis] - new_centroids, axis=2)

        # Greedy matching: visit objects by their best (smallest) distance,
        # each taking its nearest still-unclaimed detection.
        row_order = dist.min(axis=1).argsort()
        col_choice = dist.argmin(axis=1)

        matched = {}
        taken_rows = set()
        taken_cols = set()
        for r in row_order:
            c = col_choice[r]
            if r in taken_rows or c in taken_cols:
                continue
            if dist[r, c] > self.max_distance:
                continue
            oid = ids[r]
            self.objects[oid] = new_centroids[c]
            self.disappeared[oid] = 0
            matched[oid] = detections[c]
            taken_rows.add(r)
            taken_cols.add(c)

        # Unmatched objects age (and may be dropped).
        for r in set(range(dist.shape[0])) - taken_rows:
            oid = ids[r]
            self.disappeared[oid] += 1
            if self.disappeared[oid] > self.max_disappeared:
                self.deregister(oid)

        # Unmatched detections become newly tracked objects.
        for c in set(range(dist.shape[1])) - taken_cols:
            matched[self.register(new_centroids[c])] = detections[c]

        return matched

284
src/court_calibrator.py Normal file
View File

@@ -0,0 +1,284 @@
"""
Court calibration module for mapping pixel coordinates to real-world coordinates
Uses homography transformation based on court keypoints
"""
import numpy as np
import cv2
from typing import List, Tuple, Optional, Dict
import json
class CourtCalibrator:
    """
    Maps between image (pixel) coordinates and real-world court coordinates.

    A planar homography is estimated from the four court corners; once
    available, points can be projected in either direction.
    """

    def __init__(self, court_width_m: float = 6.1, court_length_m: float = 13.4):
        """
        Initialize court calibrator.

        Args:
            court_width_m: Width of pickleball court in meters (default: 6.1m)
            court_length_m: Length of pickleball court in meters (default: 13.4m)
        """
        self.court_width = court_width_m
        self.court_length = court_length_m
        self.homography_matrix = None
        self.court_corners_pixel = None
        # Real-world corner layout in meters, ordered TL, TR, BR, BL.
        self.court_corners_real = np.array([
            [0, 0],                           # Top-left
            [court_length_m, 0],              # Top-right
            [court_length_m, court_width_m],  # Bottom-right
            [0, court_width_m]                # Bottom-left
        ], dtype=np.float32)

    def calibrate_manual(self, corner_points: List[Tuple[float, float]]) -> bool:
        """
        Estimate the homography from 4 manually supplied corner points.

        Args:
            corner_points: Four (x, y) pixel tuples ordered
                [top-left, top-right, bottom-right, bottom-left]

        Returns:
            True on success, False otherwise
        """
        if len(corner_points) != 4:
            print("Error: Need exactly 4 corner points")
            return False
        self.court_corners_pixel = np.array(corner_points, dtype=np.float32)
        matrix, _status = cv2.findHomography(
            self.court_corners_pixel,
            self.court_corners_real,
            method=cv2.RANSAC
        )
        self.homography_matrix = matrix
        if self.homography_matrix is None:
            print("Error: Failed to calculate homography matrix")
            return False
        print("✓ Court calibration successful")
        return True

    def calibrate_auto(self, frame: np.ndarray) -> bool:
        """
        Attempt to locate court corners automatically (placeholder).

        Line candidates are found with Canny + probabilistic Hough, but the
        corner-extraction step is not implemented; manual calibration is
        the supported path.

        Args:
            frame: Video frame to detect court corners from

        Returns:
            True if calibration successful, False otherwise
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150, apertureSize=3)
        lines = cv2.HoughLinesP(edges, 1, np.pi/180, threshold=100,
                                minLineLength=100, maxLineGap=10)
        if lines is None or len(lines) < 4:
            print("Error: Could not detect enough lines for court corners")
            return False
        # A production version would add: a keypoint model trained on court
        # corners, line-intersection analysis, or court template matching.
        print("Warning: Auto-calibration not fully implemented")
        print("Please use calibrate_manual() with corner points")
        return False

    def pixel_to_real(self, pixel_coords: Tuple[float, float]) -> Optional[Tuple[float, float]]:
        """
        Project a pixel coordinate onto the court plane (meters).

        Args:
            pixel_coords: (x, y) tuple in pixel space

        Returns:
            (x, y) in meters, or None when not calibrated
        """
        if self.homography_matrix is None:
            print("Error: Court not calibrated. Call calibrate_manual() first")
            return None
        # perspectiveTransform wants shape (N, 1, 2)
        src = np.array([[pixel_coords[0], pixel_coords[1]]], dtype=np.float32)
        src = src.reshape(-1, 1, 2)
        projected = cv2.perspectiveTransform(src, self.homography_matrix)
        out_x, out_y = projected[0][0]
        return (float(out_x), float(out_y))

    def real_to_pixel(self, real_coords: Tuple[float, float]) -> Optional[Tuple[float, float]]:
        """
        Project a court-plane coordinate (meters) back into pixel space.

        Args:
            real_coords: (x, y) tuple in real-world meters

        Returns:
            (x, y) in pixels, or None when not calibrated
        """
        if self.homography_matrix is None:
            print("Error: Court not calibrated")
            return None
        # Invert the forward homography for the reverse projection
        backward = np.linalg.inv(self.homography_matrix)
        src = np.array([[real_coords[0], real_coords[1]]], dtype=np.float32)
        src = src.reshape(-1, 1, 2)
        projected = cv2.perspectiveTransform(src, backward)
        out_x, out_y = projected[0][0]
        return (float(out_x), float(out_y))

    def is_calibrated(self) -> bool:
        """Return True once a homography has been computed or loaded."""
        return self.homography_matrix is not None

    def draw_court_overlay(self, frame: np.ndarray) -> np.ndarray:
        """
        Render the calibrated court boundary onto a copy of the frame.

        Args:
            frame: Input frame

        Returns:
            Annotated copy, or the frame untouched when not calibrated
        """
        if not self.is_calibrated():
            return frame
        overlay = frame.copy()
        if self.court_corners_pixel is not None:
            # Corner markers first, then the closed boundary polygon
            for corner in self.court_corners_pixel:
                cv2.circle(overlay, (int(corner[0]), int(corner[1])), 10, (0, 255, 0), -1)
            outline = self.court_corners_pixel.astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(overlay, [outline], True, (0, 255, 0), 2)
        return overlay

    def save_calibration(self, filepath: str):
        """Persist the calibration (dimensions, homography, corners) as JSON."""
        if not self.is_calibrated():
            print("Error: No calibration to save")
            return
        payload = {
            'court_width_m': self.court_width,
            'court_length_m': self.court_length,
            'homography_matrix': self.homography_matrix.tolist(),
            'court_corners_pixel': self.court_corners_pixel.tolist()
        }
        with open(filepath, 'w') as f:
            json.dump(payload, f, indent=2)
        print(f"✓ Calibration saved to {filepath}")

    def load_calibration(self, filepath: str) -> bool:
        """Restore a calibration previously written by save_calibration()."""
        try:
            with open(filepath, 'r') as f:
                payload = json.load(f)
            self.court_width = payload['court_width_m']
            self.court_length = payload['court_length_m']
            self.homography_matrix = np.array(payload['homography_matrix'])
            self.court_corners_pixel = np.array(payload['court_corners_pixel'])
            print(f"✓ Calibration loaded from {filepath}")
            return True
        except Exception as e:
            print(f"Error loading calibration: {e}")
            return False
class InteractiveCalibrator:
    """
    Interactive tool for manual court calibration
    Click on 4 corners of the court in order: top-left, top-right, bottom-right, bottom-left
    """
    def __init__(self):
        # Collected (x, y) click positions, in click order (at most 4)
        self.points = []
        self.window_name = "Court Calibration - Click 4 corners (TL, TR, BR, BL)"
    def _mouse_callback(self, event, x, y, flags, param):
        """Handle mouse clicks"""
        # Record left-clicks until four corners have been captured
        if event == cv2.EVENT_LBUTTONDOWN and len(self.points) < 4:
            self.points.append((x, y))
            print(f"Point {len(self.points)}: ({x}, {y})")
    def calibrate_interactive(self, frame: np.ndarray) -> Optional[List[Tuple[float, float]]]:
        """
        Interactive calibration - user clicks 4 corners
        Args:
            frame: First frame of video to calibrate on
        Returns:
            List of 4 corner points, or None if cancelled
        """
        display = frame.copy()
        cv2.namedWindow(self.window_name)
        cv2.setMouseCallback(self.window_name, self._mouse_callback)
        print("\nClick on 4 court corners in this order:")
        print("1. Top-left")
        print("2. Top-right")
        print("3. Bottom-right")
        print("4. Bottom-left")
        print("Press 'q' to cancel, 'r' to reset")
        # Event loop: redraw the annotated frame every iteration so newly
        # clicked points (added asynchronously by the callback) show up
        while True:
            # Draw points
            temp = display.copy()
            for i, point in enumerate(self.points):
                cv2.circle(temp, point, 5, (0, 255, 0), -1)
                cv2.putText(temp, str(i+1), (point[0]+10, point[1]),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
            # Draw lines between points
            if len(self.points) > 1:
                for i in range(len(self.points) - 1):
                    cv2.line(temp, self.points[i], self.points[i+1], (0, 255, 0), 2)
            cv2.imshow(self.window_name, temp)
            # waitKey also pumps the GUI event queue; mask to low byte for
            # portable key-code comparison
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                print("Calibration cancelled")
                cv2.destroyAllWindows()
                return None
            if key == ord('r'):
                print("Reset points")
                self.points = []
            if len(self.points) == 4:
                print("\n✓ All 4 points selected")
                cv2.destroyAllWindows()
                return self.points
        # Unreachable: the loop only exits via the returns above
        return None

378
src/video_processor.py Normal file
View File

@@ -0,0 +1,378 @@
"""
Main video processing pipeline
Combines ball detection, court calibration, and tracking
"""
import cv2
import numpy as np
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from tqdm import tqdm
import time
from .ball_detector import BallDetector
from .court_calibrator import CourtCalibrator, InteractiveCalibrator
from .ball_tracker import BallTracker
class VideoProcessor:
    """
    Main pipeline for processing pickleball videos.

    Wires together ball detection, optional court calibration, and
    buffer-based ball tracking, and emits per-frame trajectory data.
    """

    def __init__(
        self,
        model_id: str = "pickleball-detection-1oqlw/1",
        config_path: Optional[str] = None
    ):
        """
        Initialize video processor.

        Args:
            model_id: Roboflow model ID
            config_path: Path to configuration JSON file (optional)
        """
        # Load config (user file deep-merged over built-in defaults)
        self.config = self._load_config(config_path)
        # Initialize components
        self.detector = BallDetector(
            model_id=model_id,
            confidence_threshold=self.config['detection']['confidence_threshold'],
            slice_enabled=self.config['detection']['slice_enabled'],
            slice_height=self.config['detection']['slice_height'],
            slice_width=self.config['detection']['slice_width']
        )
        self.calibrator = CourtCalibrator(
            court_width_m=self.config['court']['width_m'],
            court_length_m=self.config['court']['length_m']
        )
        self.tracker = BallTracker(
            buffer_size=self.config['tracking']['buffer_size'],
            max_distance_threshold=self.config['tracking']['max_distance_threshold']
        )
        # Video state (populated by load_video)
        self.video_path = None
        self.cap = None
        self.fps = None
        self.total_frames = None

    def _load_config(self, config_path: Optional[str]) -> Dict:
        """
        Load configuration from a JSON file, deep-merged with defaults.

        Bug fix: the merge previously filled in only missing top-level
        sections, so a partial section (e.g. 'detection' without
        'slice_height') caused a KeyError in __init__. Missing keys inside
        each section are now filled from the defaults as well.
        """
        default_config = {
            'court': {
                'width_m': 6.1,
                'length_m': 13.4
            },
            'detection': {
                'confidence_threshold': 0.4,
                'iou_threshold': 0.5,
                'slice_enabled': True,
                'slice_height': 320,
                'slice_width': 320
            },
            'tracking': {
                'buffer_size': 10,
                'max_distance_threshold': 100
            }
        }
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                config = json.load(f)
            # Deep merge: add missing sections, then missing keys per section
            for section, defaults in default_config.items():
                if section not in config:
                    config[section] = defaults
                elif isinstance(defaults, dict) and isinstance(config[section], dict):
                    for key, value in defaults.items():
                        config[section].setdefault(key, value)
            return config
        return default_config

    def load_video(self, video_path: str) -> bool:
        """
        Load video file.

        Args:
            video_path: Path to video file

        Returns:
            True if successful, False otherwise
        """
        self.video_path = video_path
        self.cap = cv2.VideoCapture(video_path)
        if not self.cap.isOpened():
            print(f"Error: Could not open video {video_path}")
            return False
        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"✓ Loaded video: {video_path}")
        print(f"  FPS: {self.fps}")
        print(f"  Total frames: {self.total_frames}")
        print(f"  Duration: {self.total_frames / self.fps:.2f} seconds")
        return True

    def calibrate_court(
        self,
        corner_points: Optional[List[Tuple[float, float]]] = None,
        interactive: bool = False
    ) -> bool:
        """
        Calibrate court for coordinate transformation.

        Args:
            corner_points: Manual corner points [TL, TR, BR, BL], or None
            interactive: If True, use the interactive calibration tool

        Returns:
            True if calibration successful
        """
        if corner_points is not None:
            return self.calibrator.calibrate_manual(corner_points)
        if interactive:
            # Rewind and grab the first frame for the user to click on
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ret, frame = self.cap.read()
            if not ret:
                print("Error: Could not read first frame")
                return False
            calibrator = InteractiveCalibrator()
            points = calibrator.calibrate_interactive(frame)
            if points is None:
                return False
            return self.calibrator.calibrate_manual(points)
        print("Warning: No calibration method specified")
        print("Processing will continue without coordinate transformation")
        return False

    def process_video(
        self,
        output_path: Optional[str] = None,
        save_visualization: bool = False,
        visualization_path: Optional[str] = None,
        start_frame: int = 0,
        end_frame: Optional[int] = None
    ) -> Dict:
        """
        Process video and extract ball trajectory.

        Args:
            output_path: Path to save JSON results (optional)
            save_visualization: If True, save video with annotations
            visualization_path: Path to save visualization video
            start_frame: Frame to start processing from
            end_frame: Frame to end processing at (None = end of video)

        Returns:
            Dictionary with processing results

        Raises:
            ValueError: If no video has been loaded
        """
        if self.cap is None:
            raise ValueError("No video loaded. Call load_video() first")
        # Start each run from a clean tracker state
        self.tracker.reset()
        if end_frame is None:
            end_frame = self.total_frames
        # Setup visualization writer if needed
        video_writer = None
        if save_visualization:
            if visualization_path is None:
                visualization_path = str(Path(self.video_path).stem) + "_tracked.mp4"
            frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(
                visualization_path,
                fourcc,
                self.fps,
                (frame_width, frame_height)
            )
        # Result skeleton; per-frame entries are appended below
        results = {
            'video_path': self.video_path,
            'fps': self.fps,
            'total_frames': end_frame - start_frame,
            'duration_sec': (end_frame - start_frame) / self.fps,
            'court': {
                'width_m': self.config['court']['width_m'],
                'length_m': self.config['court']['length_m'],
                'calibrated': self.calibrator.is_calibrated()
            },
            'frames': []
        }
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        print(f"\nProcessing frames {start_frame} to {end_frame}...")
        start_time = time.time()
        for frame_num in tqdm(range(start_frame, end_frame)):
            ret, frame = self.cap.read()
            if not ret:
                # Video ended earlier than expected — stop cleanly
                break
            # Detect, then track (tracker filters/interpolates detections)
            detections = self.detector.detect(frame)
            ball = self.tracker.update(detections, frame_num)
            frame_data = {
                'frame_number': frame_num,
                'timestamp': frame_num / self.fps,
                'ball': None
            }
            if ball is not None:
                pixel_coords = ball['center']
                real_coords = None
                if self.calibrator.is_calibrated():
                    real_coords = self.calibrator.pixel_to_real(pixel_coords)
                frame_data['ball'] = {
                    'detected': True,
                    'pixel_coords': {
                        'x': float(pixel_coords[0]),
                        'y': float(pixel_coords[1])
                    },
                    'real_coords_m': {
                        'x': float(real_coords[0]) if real_coords else None,
                        'y': float(real_coords[1]) if real_coords else None
                    } if real_coords else None,
                    'confidence': float(ball['confidence']),
                    'interpolated': ball.get('interpolated', False)
                }
            results['frames'].append(frame_data)
            if save_visualization:
                vis_frame = self._draw_visualization(frame, ball, frame_num)
                video_writer.write(vis_frame)
        # Cleanup
        if video_writer is not None:
            video_writer.release()
        processing_time = time.time() - start_time
        results['processing_time_sec'] = processing_time
        results['fps_processing'] = (end_frame - start_frame) / processing_time
        print(f"\n✓ Processing complete!")
        print(f"  Time: {processing_time:.2f} seconds")
        print(f"  Speed: {results['fps_processing']:.2f} FPS")
        if output_path:
            self.save_results(results, output_path)
        return results

    def _draw_visualization(
        self,
        frame: np.ndarray,
        ball: Optional[Dict],
        frame_num: int
    ) -> np.ndarray:
        """
        Draw detection/tracking visualization on a copy of the frame.

        Args:
            frame: Input frame
            ball: Ball detection (None when nothing was tracked this frame)
            frame_num: Current frame number

        Returns:
            Frame with visualization
        """
        vis = frame.copy()
        # Court overlay first so ball annotations draw on top of it
        if self.calibrator.is_calibrated():
            vis = self.calibrator.draw_court_overlay(vis)
        vis = self.tracker.draw_trajectory(vis)
        if ball is not None:
            center = ball['center']
            # Green for real detections, blue for interpolated positions
            color = (0, 255, 0) if not ball.get('interpolated', False) else (255, 0, 0)
            bbox = ball.get('bbox', self.tracker._estimate_bbox(center))
            cv2.rectangle(
                vis,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[2]), int(bbox[3])),
                color,
                2
            )
            cv2.circle(vis, (int(center[0]), int(center[1])), 5, color, -1)
            # Confidence label (interpolated positions carry confidence 0)
            if ball['confidence'] > 0:
                cv2.putText(
                    vis,
                    f"{ball['confidence']:.2f}",
                    (int(center[0]) + 10, int(center[1]) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    2
                )
        cv2.putText(
            vis,
            f"Frame: {frame_num}",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 255, 255),
            2
        )
        return vis

    def save_results(self, results: Dict, output_path: str):
        """
        Save processing results to JSON file.

        Args:
            results: Results dictionary
            output_path: Path to save JSON
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w') as f:
            json.dump(results, f, indent=2)
        print(f"✓ Results saved to {output_path}")

    def close(self):
        """Release video resources"""
        if self.cap is not None:
            self.cap.release()