Initial commit
This commit is contained in:
17
src/__init__.py
Normal file
17
src/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Pickle - Pickleball Ball Tracking System
|
||||
"""
|
||||
|
||||
from .ball_detector import BallDetector
|
||||
from .court_calibrator import CourtCalibrator, InteractiveCalibrator
|
||||
from .ball_tracker import BallTracker, MultiObjectTracker
|
||||
from .video_processor import VideoProcessor
|
||||
|
||||
__all__ = [
|
||||
'BallDetector',
|
||||
'CourtCalibrator',
|
||||
'InteractiveCalibrator',
|
||||
'BallTracker',
|
||||
'MultiObjectTracker',
|
||||
'VideoProcessor'
|
||||
]
|
||||
263
src/ball_detector.py
Normal file
263
src/ball_detector.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Ball detector module using Roboflow Hosted Inference and YOLO v8
|
||||
"""
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
from inference_sdk import InferenceHTTPClient
|
||||
import cv2
|
||||
|
||||
|
||||
class BallDetector:
    """
    Detects pickleball balls in video frames using Roboflow pre-trained models.

    Implements an InferenceSlicer-style tiling technique for better small
    object detection. Falls back to a local YOLOv8 model (COCO "sports ball"
    class) when no Roboflow API key is available or the hosted client fails
    to initialize.
    """

    def __init__(
        self,
        model_id: str = "pickleball-detection-1oqlw/1",
        confidence_threshold: float = 0.4,
        iou_threshold: float = 0.5,
        slice_enabled: bool = True,
        slice_height: int = 320,
        slice_width: int = 320,
        overlap_ratio: float = 0.2
    ):
        """
        Initialize the ball detector

        Args:
            model_id: Roboflow model ID (format: workspace/project/version)
            confidence_threshold: Minimum confidence for detections
            iou_threshold: IoU threshold for NMS
            slice_enabled: Enable frame slicing for better small object detection
            slice_height: Height of each slice
            slice_width: Width of each slice
            overlap_ratio: Overlap ratio between slices
        """
        self.model_id = model_id
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.slice_enabled = slice_enabled
        self.slice_height = slice_height
        self.slice_width = slice_width
        self.overlap_ratio = overlap_ratio

        # Initialize Roboflow Hosted Inference client; without an API key we
        # go straight to the local YOLO fallback.
        api_key = os.getenv("ROBOFLOW_API_KEY")

        if not api_key:
            print("✗ ROBOFLOW_API_KEY not set, using YOLO v8 fallback")
            self._init_yolo_fallback()
        else:
            try:
                self.client = InferenceHTTPClient(
                    api_url="https://serverless.roboflow.com",
                    api_key=api_key
                )
                print(f"✓ Initialized Roboflow Hosted Inference: {model_id}")
                self.use_fallback = False
            except Exception as e:
                print(f"✗ Failed to initialize Roboflow client: {e}")
                print("Falling back to YOLO v8 base model for sports ball detection")
                self._init_yolo_fallback()

    def _init_yolo_fallback(self):
        """Load the local YOLOv8 nano model and switch to fallback mode."""
        # Imported lazily so the hosted-inference path never requires
        # ultralytics to be installed.
        from ultralytics import YOLO
        self.model = YOLO('yolov8n.pt')
        self.use_fallback = True

    def detect(self, frame: np.ndarray) -> List[Dict]:
        """
        Detect balls in a single frame

        Args:
            frame: Input frame (numpy array in BGR format)

        Returns:
            List of detections with format:
            [
                {
                    'bbox': [x1, y1, x2, y2],
                    'confidence': float,
                    'center': [cx, cy]
                },
                ...
            ]
        """
        if self.use_fallback:
            return self._detect_with_yolo(frame)

        if self.slice_enabled:
            return self._detect_with_slicing(frame)
        return self._detect_single(frame)

    def _infer_on_image(self, image: np.ndarray) -> List[Dict]:
        """
        Run hosted inference on one BGR image and return parsed detections.

        The hosted API expects a file path (or base64), so the image is
        written to a temporary JPEG that is always removed afterwards.
        """
        import tempfile
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
            cv2.imwrite(tmp.name, image)
            tmp_path = tmp.name
        try:
            results = self.client.infer(tmp_path, model_id=self.model_id)
            return self._parse_results(results)
        finally:
            # delete=False above means we must clean up ourselves
            os.unlink(tmp_path)

    def _detect_single(self, frame: np.ndarray) -> List[Dict]:
        """Detect on full frame without slicing"""
        try:
            return self._infer_on_image(frame)
        except Exception as e:
            print(f"Detection error: {e}")
            return []

    def _detect_with_slicing(self, frame: np.ndarray) -> List[Dict]:
        """
        Detect using InferenceSlicer technique
        Divides frame into overlapping tiles for better small object detection
        """
        height, width = frame.shape[:2]
        detections = []

        # Stride leaves `overlap_ratio` of each tile overlapping its neighbor
        stride_h = int(self.slice_height * (1 - self.overlap_ratio))
        stride_w = int(self.slice_width * (1 - self.overlap_ratio))

        for y in range(0, height, stride_h):
            for x in range(0, width, stride_w):
                # Extract slice, clamped to the frame border
                y_end = min(y + self.slice_height, height)
                x_end = min(x + self.slice_width, width)
                slice_img = frame[y:y_end, x:x_end]

                try:
                    slice_detections = self._infer_on_image(slice_img)
                except Exception:
                    # Best-effort: one failed tile should not abort the frame
                    continue

                # Shift slice-local coordinates back into full-frame space
                for det in slice_detections:
                    det['bbox'][0] += x   # x1
                    det['bbox'][1] += y   # y1
                    det['bbox'][2] += x   # x2
                    det['bbox'][3] += y   # y2
                    det['center'][0] += x  # cx
                    det['center'][1] += y  # cy
                    detections.append(det)

        # Overlapping tiles detect the same ball more than once; NMS
        # collapses those duplicates.
        return self._apply_nms(detections)

    def _detect_with_yolo(self, frame: np.ndarray) -> List[Dict]:
        """Fallback detection using ultralytics YOLO"""
        results = self.model(frame, conf=self.confidence_threshold, verbose=False)

        detections = []
        for result in results:
            for box in result.boxes:
                # Filter for sports ball class (class 32 in COCO)
                if int(box.cls[0]) == 32:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    cx = (x1 + x2) / 2
                    cy = (y1 + y2) / 2

                    detections.append({
                        'bbox': [float(x1), float(y1), float(x2), float(y2)],
                        'confidence': float(box.conf[0]),
                        'center': [float(cx), float(cy)]
                    })

        return detections

    def _parse_results(self, results) -> List[Dict]:
        """Parse Roboflow Hosted API results into the internal detection format."""
        detections = []

        if not results:
            return detections

        # Hosted API returns dict with 'predictions' key
        predictions = results.get('predictions', []) if isinstance(results, dict) else []

        for pred in predictions:
            # Hosted API format: 'x'/'y' are the box CENTER, plus
            # 'width', 'height', 'class', 'confidence'
            if not isinstance(pred, dict):
                continue

            # Extract bbox (convert center+size to corner coordinates)
            if 'x' in pred and 'y' in pred and 'width' in pred and 'height' in pred:
                cx = pred['x']
                cy = pred['y']
                w = pred['width']
                h = pred['height']
                x1 = cx - w / 2
                y1 = cy - h / 2
                x2 = cx + w / 2
                y2 = cy + h / 2
            else:
                continue

            # Filter for ball class (substring match, so e.g. "pickleball" passes)
            class_name = pred.get('class', '')
            if 'ball' not in class_name.lower():
                continue

            detections.append({
                'bbox': [x1, y1, x2, y2],
                'confidence': pred.get('confidence', 0.0),
                'center': [cx, cy]
            })

        return detections

    def _apply_nms(self, detections: List[Dict]) -> List[Dict]:
        """Apply Non-Maximum Suppression to remove duplicate detections."""
        if len(detections) == 0:
            return []

        # BUG FIX: cv2.dnn.NMSBoxes expects boxes as [x, y, width, height],
        # but the detections store corner coordinates [x1, y1, x2, y2].
        # Passing corners directly made NMS compute IoU on wrong geometry.
        xywh = [
            [x1, y1, x2 - x1, y2 - y1]
            for x1, y1, x2, y2 in (det['bbox'] for det in detections)
        ]
        scores = [det['confidence'] for det in detections]

        indices = cv2.dnn.NMSBoxes(
            xywh,
            scores,
            self.confidence_threshold,
            self.iou_threshold
        )

        if len(indices) == 0:
            return []

        # Some OpenCV versions return a nested array, others a flat one
        return [detections[i] for i in np.array(indices).flatten()]
|
||||
419
src/ball_tracker.py
Normal file
419
src/ball_tracker.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Ball tracking module with buffer-based filtering and trajectory smoothing
|
||||
"""
|
||||
import numpy as np
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from collections import deque
|
||||
import cv2
|
||||
|
||||
|
||||
class BallTracker:
    """
    Tracks ball across frames using buffer-based filtering
    Handles occlusions and false detections

    State machine summary:
      - Each accepted detection is appended to ``position_buffer`` (a fixed-size
        deque) and to ``trajectory`` (unbounded history).
      - Frames with no acceptable detection increment ``lost_frames``; for up to
        ``max_lost_frames`` frames a position is linearly extrapolated from the
        last two buffered positions and reported with confidence 0.0.
    """

    def __init__(
        self,
        buffer_size: int = 10,
        max_distance_threshold: int = 100,
        min_confidence: float = 0.3
    ):
        """
        Initialize ball tracker

        Args:
            buffer_size: Number of recent positions to store for filtering
            max_distance_threshold: Maximum pixel distance between frames to consider same ball
            min_confidence: Minimum confidence threshold for detections
        """
        self.buffer_size = buffer_size
        self.max_distance_threshold = max_distance_threshold
        self.min_confidence = min_confidence

        # Buffer for storing recent ball positions (oldest entries drop off
        # automatically thanks to deque maxlen)
        self.position_buffer = deque(maxlen=buffer_size)

        # Track ball state
        self.current_position = None
        self.lost_frames = 0
        self.max_lost_frames = 10  # Maximum frames to interpolate when ball is lost

        # Trajectory history
        self.trajectory = []

    def update(self, detections: List[Dict], frame_number: int) -> Optional[Dict]:
        """
        Update tracker with new detections from current frame

        Args:
            detections: List of ball detections from detector
            frame_number: Current frame number

        Returns:
            Best ball detection (filtered), or None if no valid detection
        """
        # Filter low confidence detections
        valid_detections = [
            det for det in detections
            if det['confidence'] >= self.min_confidence
        ]

        if len(valid_detections) == 0:
            # No detections - try to interpolate if recently had detection
            return self._handle_missing_detection(frame_number)

        # Select best detection
        best_detection = self._select_best_detection(valid_detections)

        if best_detection is None:
            return self._handle_missing_detection(frame_number)

        # Update buffer and state (real detection resets the lost counter)
        self.position_buffer.append(best_detection['center'])
        self.current_position = best_detection['center']
        self.lost_frames = 0

        # Add to trajectory. 'position' and 'pixel_coords' intentionally hold
        # the same value; 'real_coords' is filled in later by the processor
        # after court calibration.
        self.trajectory.append({
            'frame': frame_number,
            'position': best_detection['center'],
            'pixel_coords': best_detection['center'],
            'real_coords': None,  # Will be filled by processor
            'confidence': best_detection['confidence']
        })

        return best_detection

    def _select_best_detection(self, detections: List[Dict]) -> Optional[Dict]:
        """
        Select the most likely ball detection from multiple candidates

        Args:
            detections: List of valid detections

        Returns:
            Best detection, or None
        """
        if len(detections) == 0:
            return None

        if len(detections) == 1:
            return detections[0]

        # If we have position history, use it to filter
        if len(self.position_buffer) > 0:
            return self._select_by_proximity(detections)
        else:
            # No history - return highest confidence
            return max(detections, key=lambda d: d['confidence'])

    def _select_by_proximity(self, detections: List[Dict]) -> Optional[Dict]:
        """
        Select detection closest to predicted position based on buffer

        Args:
            detections: List of detections

        Returns:
            Detection closest to predicted position
        """
        # Calculate average position from buffer (smoothed recent location)
        avg_position = np.mean(self.position_buffer, axis=0)

        # Find detection closest to average
        min_distance = float('inf')
        best_detection = None

        for det in detections:
            distance = np.linalg.norm(
                np.array(det['center']) - avg_position
            )

            # Check if within threshold
            if distance < self.max_distance_threshold and distance < min_distance:
                min_distance = distance
                best_detection = det

        # If no detection within threshold, return highest confidence.
        # NOTE: this means a far-away detection is still accepted rather
        # than dropped - the distance threshold only breaks ties.
        if best_detection is None:
            best_detection = max(detections, key=lambda d: d['confidence'])

        return best_detection

    def _handle_missing_detection(self, frame_number: int) -> Optional[Dict]:
        """
        Handle case when no valid detection in current frame

        Args:
            frame_number: Current frame number

        Returns:
            Interpolated detection if possible, None otherwise
        """
        self.lost_frames += 1

        # If lost for too long, stop interpolating
        if self.lost_frames > self.max_lost_frames:
            self.current_position = None
            return None

        # Try to interpolate based on recent positions.
        # NOTE(review): the interpolated point is NOT appended to
        # position_buffer, so _interpolate_position keeps predicting from the
        # same last two real positions - the extrapolation advances one
        # velocity step and then holds there on subsequent lost frames.
        if len(self.position_buffer) >= 2:
            interpolated = self._interpolate_position()

            # Add interpolated position to trajectory (marked as interpolated)
            self.trajectory.append({
                'frame': frame_number,
                'position': interpolated,
                'pixel_coords': interpolated,
                'real_coords': None,
                'confidence': 0.0,  # Mark as interpolated
                'interpolated': True
            })

            return {
                'center': interpolated,
                'bbox': self._estimate_bbox(interpolated),
                'confidence': 0.0,
                'interpolated': True
            }

        return None

    def _interpolate_position(self) -> Tuple[float, float]:
        """
        Interpolate ball position based on recent trajectory

        Returns:
            Estimated (x, y) position
        """
        # Defensive guard; callers only invoke this with >= 2 buffered points
        if len(self.position_buffer) < 2:
            return self.current_position

        # Simple linear interpolation based on last two positions
        positions = np.array(self.position_buffer)

        # Calculate velocity (displacement over the last frame pair)
        velocity = positions[-1] - positions[-2]

        # Predict next position (constant-velocity assumption)
        predicted = positions[-1] + velocity

        return tuple(predicted)

    def _estimate_bbox(self, center: Tuple[float, float], size: int = 20) -> List[float]:
        """
        Estimate bounding box for interpolated position

        Args:
            center: Center position
            size: Estimated ball size in pixels

        Returns:
            [x1, y1, x2, y2] bbox
        """
        cx, cy = center
        half_size = size / 2
        return [cx - half_size, cy - half_size, cx + half_size, cy + half_size]

    def get_trajectory(self) -> List[Dict]:
        """
        Get full trajectory history

        Returns:
            List of trajectory points
        """
        return self.trajectory

    def get_smoothed_trajectory(self, window_size: int = 5) -> List[Dict]:
        """
        Get smoothed trajectory using moving average

        Args:
            window_size: Size of smoothing window

        Returns:
            Smoothed trajectory
        """
        if len(self.trajectory) < window_size:
            return self.trajectory

        smoothed = []

        for i, point in enumerate(self.trajectory):
            # Get window (centered on i, clipped at the trajectory ends)
            start = max(0, i - window_size // 2)
            end = min(len(self.trajectory), i + window_size // 2 + 1)
            window = self.trajectory[start:end]

            # Calculate average position (only non-interpolated points)
            valid_positions = [
                p['position'] for p in window
                if not p.get('interpolated', False)
            ]

            if len(valid_positions) > 0:
                avg_position = np.mean(valid_positions, axis=0)
                # Original point is copied so the raw trajectory stays intact
                smoothed_point = point.copy()
                smoothed_point['smoothed_position'] = tuple(avg_position)
                smoothed.append(smoothed_point)
            else:
                smoothed.append(point)

        return smoothed

    def reset(self):
        """Reset tracker state"""
        self.position_buffer.clear()
        self.current_position = None
        self.lost_frames = 0
        self.trajectory = []

    def draw_trajectory(self, frame: np.ndarray, max_points: int = 30) -> np.ndarray:
        """
        Draw ball trajectory on frame

        Args:
            frame: Input frame
            max_points: Maximum number of trajectory points to draw

        Returns:
            Frame with trajectory overlay
        """
        if len(self.trajectory) == 0:
            return frame

        overlay = frame.copy()

        # Get recent trajectory points
        recent_trajectory = self.trajectory[-max_points:]

        # Draw trajectory line (yellow polyline over recent positions)
        points = [point['position'] for point in recent_trajectory]
        if len(points) > 1:
            pts = np.array(points, dtype=np.int32).reshape((-1, 1, 2))
            cv2.polylines(overlay, [pts], False, (0, 255, 255), 2)

        # Draw current position (filled red circle)
        if self.current_position is not None:
            cv2.circle(
                overlay,
                (int(self.current_position[0]), int(self.current_position[1])),
                8,
                (0, 0, 255),
                -1
            )

        return overlay
|
||||
|
||||
|
||||
class MultiObjectTracker:
    """
    Extended tracker for tracking multiple objects (ball, players, paddles)
    Uses simple centroid-based tracking

    Association is greedy nearest-neighbor: each update, detections are
    matched to existing objects in order of increasing distance, and
    unmatched detections become new object IDs.
    """

    def __init__(self, max_disappeared: int = 10, max_distance: int = 100):
        """
        Initialize multi-object tracker

        Args:
            max_disappeared: Max frames object can disappear before removing
            max_distance: Max distance to associate detection with existing object
        """
        self.next_object_id = 0
        self.objects = {}  # object_id -> centroid
        self.disappeared = {}  # object_id -> num_frames_disappeared
        self.max_disappeared = max_disappeared
        self.max_distance = max_distance

    def register(self, centroid: Tuple[float, float]) -> int:
        """Register new object and return its freshly assigned ID"""
        object_id = self.next_object_id
        self.objects[object_id] = centroid
        self.disappeared[object_id] = 0
        self.next_object_id += 1
        return object_id

    def deregister(self, object_id: int):
        """Remove object from tracking"""
        del self.objects[object_id]
        del self.disappeared[object_id]

    def update(self, detections: List[Dict]) -> Dict[int, Dict]:
        """
        Update tracker with new detections

        Args:
            detections: List of detections with 'center' key

        Returns:
            Dict mapping object_id to detection
        """
        # If no detections, mark all as disappeared
        if len(detections) == 0:
            # Iterate over a copy of the keys because deregister mutates the dict
            for object_id in list(self.disappeared.keys()):
                self.disappeared[object_id] += 1
                if self.disappeared[object_id] > self.max_disappeared:
                    self.deregister(object_id)
            return {}

        centroids = np.array([det['center'] for det in detections])

        # If no objects being tracked, register all
        if len(self.objects) == 0:
            result = {}
            for i, det in enumerate(detections):
                object_id = self.register(centroids[i])
                result[object_id] = det
            return result

        # Associate detections with existing objects
        object_ids = list(self.objects.keys())
        object_centroids = np.array([self.objects[oid] for oid in object_ids])

        # Calculate distance matrix: D[i, j] = distance between tracked
        # object i and detection j
        D = np.linalg.norm(
            object_centroids[:, np.newaxis] - centroids,
            axis=2
        )

        # Find best matches: visit tracked objects in order of their closest
        # available detection (greedy, not a globally optimal assignment)
        rows = D.min(axis=1).argsort()
        cols = D.argmin(axis=1)[rows]

        used_rows = set()
        used_cols = set()
        result = {}

        for row, col in zip(rows, cols):
            # Each object and each detection may be matched at most once
            if row in used_rows or col in used_cols:
                continue

            # Reject matches that jumped farther than physically plausible
            if D[row, col] > self.max_distance:
                continue

            object_id = object_ids[row]
            self.objects[object_id] = centroids[col]
            self.disappeared[object_id] = 0
            result[object_id] = detections[col]

            used_rows.add(row)
            used_cols.add(col)

        # Handle disappeared objects (tracked but unmatched this frame)
        unused_rows = set(range(D.shape[0])) - used_rows
        for row in unused_rows:
            object_id = object_ids[row]
            self.disappeared[object_id] += 1
            if self.disappeared[object_id] > self.max_disappeared:
                self.deregister(object_id)

        # Register new objects (detections with no matching tracked object)
        unused_cols = set(range(D.shape[1])) - used_cols
        for col in unused_cols:
            object_id = self.register(centroids[col])
            result[object_id] = detections[col]

        return result
|
||||
284
src/court_calibrator.py
Normal file
284
src/court_calibrator.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Court calibration module for mapping pixel coordinates to real-world coordinates
|
||||
Uses homography transformation based on court keypoints
|
||||
"""
|
||||
import numpy as np
|
||||
import cv2
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
import json
|
||||
|
||||
|
||||
class CourtCalibrator:
    """
    Calibrates camera perspective to map pixel coordinates to real-world court coordinates
    Uses homography transformation

    The real-world frame puts the origin at the top-left court corner, with
    x running along the court length and y along the width (both in meters).
    """

    def __init__(self, court_width_m: float = 6.1, court_length_m: float = 13.4):
        """
        Initialize court calibrator

        Args:
            court_width_m: Width of pickleball court in meters (default: 6.1m)
            court_length_m: Length of pickleball court in meters (default: 13.4m)
        """
        self.court_width = court_width_m
        self.court_length = court_length_m
        self.homography_matrix = None
        self.court_corners_pixel = None
        # Cached inverse homography; computed lazily in real_to_pixel and
        # invalidated whenever the homography changes.
        self._inv_homography = None
        self.court_corners_real = np.array([
            [0, 0],                           # Top-left
            [court_length_m, 0],              # Top-right
            [court_length_m, court_width_m],  # Bottom-right
            [0, court_width_m]                # Bottom-left
        ], dtype=np.float32)

    def calibrate_manual(self, corner_points: List[Tuple[float, float]]) -> bool:
        """
        Manually calibrate using 4 corner points of the court

        Args:
            corner_points: List of 4 (x, y) tuples representing court corners in pixels
                          Order: [top-left, top-right, bottom-right, bottom-left]

        Returns:
            True if calibration successful, False otherwise
        """
        if len(corner_points) != 4:
            print("Error: Need exactly 4 corner points")
            return False

        self.court_corners_pixel = np.array(corner_points, dtype=np.float32)

        # FIX: with exactly 4 correspondences there are no outliers to
        # reject, so RANSAC is meaningless here (it needs more than the
        # minimal point set). Use the plain least-squares solve (method=0).
        self.homography_matrix, status = cv2.findHomography(
            self.court_corners_pixel,
            self.court_corners_real,
            method=0
        )
        # Any previously cached inverse is now stale
        self._inv_homography = None

        if self.homography_matrix is None:
            print("Error: Failed to calculate homography matrix")
            return False

        print("✓ Court calibration successful")
        return True

    def calibrate_auto(self, frame: np.ndarray) -> bool:
        """
        Automatically detect court corners using computer vision
        This is a simplified version - can be enhanced with YOLO keypoint detection

        Args:
            frame: Video frame to detect court corners from

        Returns:
            True if calibration successful, False otherwise
        """
        # Convert to grayscale
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # Apply edge detection
        edges = cv2.Canny(gray, 50, 150, apertureSize=3)

        # Detect lines using Hough Transform
        lines = cv2.HoughLinesP(edges, 1, np.pi/180, threshold=100,
                                minLineLength=100, maxLineGap=10)

        if lines is None or len(lines) < 4:
            print("Error: Could not detect enough lines for court corners")
            return False

        # This is a placeholder - in production, you'd use:
        # 1. YOLO keypoint detection model trained on court corners
        # 2. More sophisticated line intersection detection
        # 3. Court line template matching

        print("Warning: Auto-calibration not fully implemented")
        print("Please use calibrate_manual() with corner points")
        return False

    def pixel_to_real(self, pixel_coords: Tuple[float, float]) -> Optional[Tuple[float, float]]:
        """
        Transform pixel coordinates to real-world court coordinates

        Args:
            pixel_coords: (x, y) tuple in pixel space

        Returns:
            (x, y) tuple in real-world meters, or None if not calibrated
        """
        if self.homography_matrix is None:
            print("Error: Court not calibrated. Call calibrate_manual() first")
            return None

        # perspectiveTransform expects shape (N, 1, 2)
        pixel_point = np.array([[pixel_coords[0], pixel_coords[1]]], dtype=np.float32)
        pixel_point = pixel_point.reshape(-1, 1, 2)

        # Apply homography transformation
        real_point = cv2.perspectiveTransform(pixel_point, self.homography_matrix)

        x, y = real_point[0][0]
        return (float(x), float(y))

    def real_to_pixel(self, real_coords: Tuple[float, float]) -> Optional[Tuple[float, float]]:
        """
        Transform real-world court coordinates to pixel coordinates

        Args:
            real_coords: (x, y) tuple in real-world meters

        Returns:
            (x, y) tuple in pixel space, or None if not calibrated
        """
        if self.homography_matrix is None:
            print("Error: Court not calibrated")
            return None

        # FIX: compute the inverse homography once per calibration instead of
        # on every call, and fail gracefully on a singular matrix.
        if self._inv_homography is None:
            try:
                self._inv_homography = np.linalg.inv(self.homography_matrix)
            except np.linalg.LinAlgError:
                print("Error: Homography matrix is not invertible")
                return None

        real_point = np.array([[real_coords[0], real_coords[1]]], dtype=np.float32)
        real_point = real_point.reshape(-1, 1, 2)

        pixel_point = cv2.perspectiveTransform(real_point, self._inv_homography)

        x, y = pixel_point[0][0]
        return (float(x), float(y))

    def is_calibrated(self) -> bool:
        """Check if court is calibrated"""
        return self.homography_matrix is not None

    def draw_court_overlay(self, frame: np.ndarray) -> np.ndarray:
        """
        Draw court boundaries on frame for visualization

        Args:
            frame: Input frame

        Returns:
            Frame with court overlay (the input frame itself if not calibrated)
        """
        if not self.is_calibrated():
            return frame

        overlay = frame.copy()

        # Draw court corners and boundary polygon (green)
        if self.court_corners_pixel is not None:
            for point in self.court_corners_pixel:
                cv2.circle(overlay, (int(point[0]), int(point[1])), 10, (0, 255, 0), -1)

            pts = self.court_corners_pixel.astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(overlay, [pts], True, (0, 255, 0), 2)

        return overlay

    def save_calibration(self, filepath: str):
        """Save calibration data (dimensions, homography, corners) to a JSON file."""
        if not self.is_calibrated():
            print("Error: No calibration to save")
            return

        data = {
            'court_width_m': self.court_width,
            'court_length_m': self.court_length,
            'homography_matrix': self.homography_matrix.tolist(),
            'court_corners_pixel': self.court_corners_pixel.tolist()
        }

        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

        print(f"✓ Calibration saved to {filepath}")

    def load_calibration(self, filepath: str) -> bool:
        """Load calibration data from a JSON file written by save_calibration."""
        try:
            with open(filepath, 'r') as f:
                data = json.load(f)

            self.court_width = data['court_width_m']
            self.court_length = data['court_length_m']
            self.homography_matrix = np.array(data['homography_matrix'])
            self.court_corners_pixel = np.array(data['court_corners_pixel'])
            # The cached inverse belongs to the previous homography
            self._inv_homography = None

            print(f"✓ Calibration loaded from {filepath}")
            return True

        except Exception as e:
            # Best-effort loader: a missing/corrupt file reports False
            print(f"Error loading calibration: {e}")
            return False
|
||||
|
||||
|
||||
class InteractiveCalibrator:
    """
    Interactive tool for manual court calibration
    Click on 4 corners of the court in order: top-left, top-right, bottom-right, bottom-left
    """

    def __init__(self):
        # Clicked corner points, appended by the mouse callback
        self.points = []
        self.window_name = "Court Calibration - Click 4 corners (TL, TR, BR, BL)"

    def _mouse_callback(self, event, x, y, flags, param):
        """Handle mouse clicks: record up to 4 left-button click positions."""
        if event == cv2.EVENT_LBUTTONDOWN and len(self.points) < 4:
            self.points.append((x, y))
            print(f"Point {len(self.points)}: ({x}, {y})")

    def calibrate_interactive(self, frame: np.ndarray) -> Optional[List[Tuple[float, float]]]:
        """
        Interactive calibration - user clicks 4 corners

        Blocks in a cv2 GUI event loop until either 4 points are clicked
        (returns them) or the user presses 'q' (returns None). 'r' clears
        the points collected so far.

        Args:
            frame: First frame of video to calibrate on

        Returns:
            List of 4 corner points, or None if cancelled
        """
        display = frame.copy()
        cv2.namedWindow(self.window_name)
        cv2.setMouseCallback(self.window_name, self._mouse_callback)

        print("\nClick on 4 court corners in this order:")
        print("1. Top-left")
        print("2. Top-right")
        print("3. Bottom-right")
        print("4. Bottom-left")
        print("Press 'q' to cancel, 'r' to reset")

        while True:
            # Draw points (redrawn from scratch each iteration onto a copy)
            temp = display.copy()
            for i, point in enumerate(self.points):
                cv2.circle(temp, point, 5, (0, 255, 0), -1)
                cv2.putText(temp, str(i+1), (point[0]+10, point[1]),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            # Draw lines between points
            if len(self.points) > 1:
                for i in range(len(self.points) - 1):
                    cv2.line(temp, self.points[i], self.points[i+1], (0, 255, 0), 2)

            cv2.imshow(self.window_name, temp)
            # Mask to low byte for cross-platform key codes
            key = cv2.waitKey(1) & 0xFF

            if key == ord('q'):
                print("Calibration cancelled")
                cv2.destroyAllWindows()
                return None

            if key == ord('r'):
                print("Reset points")
                self.points = []

            if len(self.points) == 4:
                print("\n✓ All 4 points selected")
                cv2.destroyAllWindows()
                return self.points

        # NOTE(review): unreachable - the while True loop above only exits
        # via the return statements inside it.
        return None
|
||||
378
src/video_processor.py
Normal file
378
src/video_processor.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
Main video processing pipeline
|
||||
Combines ball detection, court calibration, and tracking
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
|
||||
from .ball_detector import BallDetector
|
||||
from .court_calibrator import CourtCalibrator, InteractiveCalibrator
|
||||
from .ball_tracker import BallTracker
|
||||
|
||||
|
||||
class VideoProcessor:
    """
    Main pipeline for processing pickleball videos.

    Wires together the three pipeline stages:
      * BallDetector    -- per-frame ball detection (Roboflow model)
      * CourtCalibrator -- pixel -> real-world coordinate transform
      * BallTracker     -- temporal association of detections

    Typical usage::

        vp = VideoProcessor(config_path="config.json")
        vp.load_video("match.mp4")
        vp.calibrate_court(interactive=True)
        results = vp.process_video(output_path="out.json")
        vp.close()
    """

    def __init__(
        self,
        model_id: str = "pickleball-detection-1oqlw/1",
        config_path: Optional[str] = None
    ):
        """
        Initialize video processor.

        Args:
            model_id: Roboflow model ID (workspace/project/version)
            config_path: Path to configuration JSON file; any keys the
                file omits fall back to built-in defaults.
        """
        # Load config (user values deep-merged over defaults per section)
        self.config = self._load_config(config_path)

        # Initialize components
        self.detector = BallDetector(
            model_id=model_id,
            confidence_threshold=self.config['detection']['confidence_threshold'],
            # Fix: the configured IoU threshold was previously read into the
            # config but never passed through, so it was silently ignored.
            iou_threshold=self.config['detection']['iou_threshold'],
            slice_enabled=self.config['detection']['slice_enabled'],
            slice_height=self.config['detection']['slice_height'],
            slice_width=self.config['detection']['slice_width']
        )

        self.calibrator = CourtCalibrator(
            court_width_m=self.config['court']['width_m'],
            court_length_m=self.config['court']['length_m']
        )

        self.tracker = BallTracker(
            buffer_size=self.config['tracking']['buffer_size'],
            max_distance_threshold=self.config['tracking']['max_distance_threshold']
        )

        # Video state -- populated by load_video()
        self.video_path: Optional[str] = None
        self.cap = None  # cv2.VideoCapture handle
        self.fps: Optional[float] = None
        self.total_frames: Optional[int] = None

    def _load_config(self, config_path: Optional[str]) -> Dict:
        """
        Load configuration from a JSON file, falling back to defaults.

        User-supplied values are merged *per key* into each default
        section, so a partial config file (e.g. one that only overrides
        ``detection.confidence_threshold``) still picks up every other
        default instead of raising KeyError later in ``__init__``.

        Args:
            config_path: Path to a JSON config file, or None for defaults.

        Returns:
            Fully-populated configuration dictionary.
        """
        default_config = {
            'court': {
                'width_m': 6.1,
                'length_m': 13.4
            },
            'detection': {
                'confidence_threshold': 0.4,
                'iou_threshold': 0.5,
                'slice_enabled': True,
                'slice_height': 320,
                'slice_width': 320
            },
            'tracking': {
                'buffer_size': 10,
                'max_distance_threshold': 100
            }
        }

        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                user_config = json.load(f)

            # Deep-merge one level: user values win, missing keys fall
            # back to defaults. (Assumes each section value is a dict,
            # matching the default schema above.)
            config = {}
            for section, defaults in default_config.items():
                merged = dict(defaults)
                merged.update(user_config.get(section, {}))
                config[section] = merged

            # Preserve any extra sections the user defined.
            for section, value in user_config.items():
                if section not in config:
                    config[section] = value

            return config

        return default_config

    def load_video(self, video_path: str) -> bool:
        """
        Load video file.

        Args:
            video_path: Path to video file

        Returns:
            True if successful, False otherwise
        """
        self.video_path = video_path
        self.cap = cv2.VideoCapture(video_path)

        if not self.cap.isOpened():
            print(f"Error: Could not open video {video_path}")
            return False

        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

        print(f"✓ Loaded video: {video_path}")
        print(f" FPS: {self.fps}")
        print(f" Total frames: {self.total_frames}")
        # Some containers report FPS as 0; guard the duration computation.
        if self.fps:
            print(f" Duration: {self.total_frames / self.fps:.2f} seconds")

        return True

    def calibrate_court(
        self,
        corner_points: Optional[List[Tuple[float, float]]] = None,
        interactive: bool = False
    ) -> bool:
        """
        Calibrate court for coordinate transformation.

        Args:
            corner_points: Manual corner points [TL, TR, BR, BL], or None for interactive
            interactive: If True, use interactive calibration tool

        Returns:
            True if calibration successful
        """
        # Explicit points take precedence over interactive mode.
        if corner_points is not None:
            return self.calibrator.calibrate_manual(corner_points)

        if interactive:
            # Rewind so calibration is done on the first frame.
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ret, frame = self.cap.read()
            if not ret:
                print("Error: Could not read first frame")
                return False

            # Interactive calibration (user clicks the 4 corners)
            calibrator = InteractiveCalibrator()
            points = calibrator.calibrate_interactive(frame)

            if points is None:
                return False

            return self.calibrator.calibrate_manual(points)

        print("Warning: No calibration method specified")
        print("Processing will continue without coordinate transformation")
        return False

    def process_video(
        self,
        output_path: Optional[str] = None,
        save_visualization: bool = False,
        visualization_path: Optional[str] = None,
        start_frame: int = 0,
        end_frame: Optional[int] = None
    ) -> Dict:
        """
        Process video and extract ball trajectory.

        Args:
            output_path: Path to save JSON results
            save_visualization: If True, save video with annotations
            visualization_path: Path to save visualization video
            start_frame: Frame to start processing from
            end_frame: Frame to end processing at (None = end of video)

        Returns:
            Dictionary with processing results

        Raises:
            ValueError: If no video has been loaded yet.
        """
        if self.cap is None:
            raise ValueError("No video loaded. Call load_video() first")

        # Reset tracker so a previous run's state can't leak in
        self.tracker.reset()

        # Set frame range
        if end_frame is None:
            end_frame = self.total_frames

        # Setup visualization writer if needed
        video_writer = None
        if save_visualization:
            if visualization_path is None:
                # NOTE: .stem drops the directory, so the default output
                # lands in the current working directory.
                visualization_path = str(Path(self.video_path).stem) + "_tracked.mp4"

            frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(
                visualization_path,
                fourcc,
                self.fps,
                (frame_width, frame_height)
            )

        # Processing results
        results = {
            'video_path': self.video_path,
            'fps': self.fps,
            'total_frames': end_frame - start_frame,
            'duration_sec': (end_frame - start_frame) / self.fps,
            'court': {
                'width_m': self.config['court']['width_m'],
                'length_m': self.config['court']['length_m'],
                'calibrated': self.calibrator.is_calibrated()
            },
            'frames': []
        }

        # Process frames
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        print(f"\nProcessing frames {start_frame} to {end_frame}...")
        start_time = time.time()

        for frame_num in tqdm(range(start_frame, end_frame)):
            ret, frame = self.cap.read()
            if not ret:
                break  # video ended earlier than CAP_PROP_FRAME_COUNT claimed

            # Detect ball
            detections = self.detector.detect(frame)

            # Track ball
            ball = self.tracker.update(detections, frame_num)

            # Transform to real-world coordinates if calibrated
            frame_data = {
                'frame_number': frame_num,
                'timestamp': frame_num / self.fps,
                'ball': None
            }

            if ball is not None:
                pixel_coords = ball['center']
                real_coords = None

                if self.calibrator.is_calibrated():
                    real_coords = self.calibrator.pixel_to_real(pixel_coords)

                # Fix: compare against None explicitly -- `if real_coords`
                # raises on array-like returns (ambiguous truth value) and
                # is the wrong test for a valid-but-zero coordinate anyway.
                frame_data['ball'] = {
                    'detected': True,
                    'pixel_coords': {
                        'x': float(pixel_coords[0]),
                        'y': float(pixel_coords[1])
                    },
                    'real_coords_m': {
                        'x': float(real_coords[0]),
                        'y': float(real_coords[1])
                    } if real_coords is not None else None,
                    'confidence': float(ball['confidence']),
                    'interpolated': ball.get('interpolated', False)
                }

            results['frames'].append(frame_data)

            # Draw visualization if needed
            if save_visualization:
                vis_frame = self._draw_visualization(frame, ball, frame_num)
                video_writer.write(vis_frame)

        # Cleanup
        if video_writer is not None:
            video_writer.release()

        processing_time = time.time() - start_time
        results['processing_time_sec'] = processing_time
        results['fps_processing'] = (end_frame - start_frame) / processing_time

        print(f"\n✓ Processing complete!")
        print(f" Time: {processing_time:.2f} seconds")
        print(f" Speed: {results['fps_processing']:.2f} FPS")

        # Save results
        if output_path:
            self.save_results(results, output_path)

        return results

    def _draw_visualization(
        self,
        frame: np.ndarray,
        ball: Optional[Dict],
        frame_num: int
    ) -> np.ndarray:
        """
        Draw annotations (court overlay, trajectory, ball box) on a frame.

        Args:
            frame: Input frame
            ball: Ball detection dict for this frame, or None
            frame_num: Current frame number

        Returns:
            Copy of the frame with visualization drawn on it.
        """
        vis = frame.copy()

        # Draw court overlay if calibrated
        if self.calibrator.is_calibrated():
            vis = self.calibrator.draw_court_overlay(vis)

        # Draw ball trajectory (tracker keeps its own history buffer)
        vis = self.tracker.draw_trajectory(vis)

        # Draw current ball
        if ball is not None:
            center = ball['center']
            # Green for real detections, blue for interpolated positions.
            color = (0, 255, 0) if not ball.get('interpolated', False) else (255, 0, 0)

            # Draw bounding box (interpolated balls have no bbox; estimate one)
            bbox = ball.get('bbox', self.tracker._estimate_bbox(center))
            cv2.rectangle(
                vis,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[2]), int(bbox[3])),
                color,
                2
            )

            # Draw center point
            cv2.circle(vis, (int(center[0]), int(center[1])), 5, color, -1)

            # Draw confidence (interpolated entries may carry confidence 0)
            if ball['confidence'] > 0:
                cv2.putText(
                    vis,
                    f"{ball['confidence']:.2f}",
                    (int(center[0]) + 10, int(center[1]) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    2
                )

        # Draw frame number
        cv2.putText(
            vis,
            f"Frame: {frame_num}",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 255, 255),
            2
        )

        return vis

    def save_results(self, results: Dict, output_path: str):
        """
        Save processing results to a JSON file, creating parent dirs.

        Args:
            results: Results dictionary
            output_path: Path to save JSON
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            json.dump(results, f, indent=2)

        print(f"✓ Results saved to {output_path}")

    def close(self):
        """Release video resources. Safe to call more than once."""
        if self.cap is not None:
            self.cap.release()
            # Drop the released handle so later calls (or process_video)
            # don't operate on a dead capture object.
            self.cap = None
|
||||
Reference in New Issue
Block a user