#!/usr/bin/env python3
"""
Pickleball detection using YOLOv8 on Jetson.

Works with a video file or camera input.
Outputs an RTSP stream with bounding boxes around detected balls.
"""
|
|
|
|
import cv2
|
|
import time
|
|
import argparse
|
|
import gi
|
|
gi.require_version('Gst', '1.0')
|
|
gi.require_version('GstRtspServer', '1.0')
|
|
from gi.repository import Gst, GstRtspServer, GLib
|
|
from ultralytics import YOLO
|
|
import threading
|
|
|
|
# Class index passed to the detector; in the COCO label set, 32 = sports ball.
BALL_CLASS_ID = 32

# Output stream geometry (pixels) and nominal frame rate (frames per second).
STREAM_WIDTH, STREAM_HEIGHT = 1280, 720
FPS = 30
|
|
|
|
|
|
class RTSPServer:
    """Simple RTSP server using GStreamer.

    Creates a ``GstRtspServer.RTSPServer`` with a single shared media factory
    whose pipeline is ``appsrc -> videoconvert -> x264enc -> rtph264pay`` and
    mounts it at ``mount``.

    NOTE(review): this class only configures and attaches the server. It does
    not push buffers into the ``appsrc`` element named ``source`` and it does
    not run a GLib main loop; presumably the caller is expected to do both —
    confirm before relying on the stream.
    """

    def __init__(self, port=8554, width=None, height=None, fps=None,
                 bitrate=2000, mount="/live"):
        """Configure the media factory and attach the server.

        Args:
            port: TCP port (service) the RTSP server listens on.
            width: Frame width advertised in the appsrc caps. Defaults to the
                module-level ``STREAM_WIDTH``.
            height: Frame height advertised in the appsrc caps. Defaults to
                the module-level ``STREAM_HEIGHT``.
            fps: Integer frame rate advertised in the appsrc caps. Defaults
                to the module-level ``FPS``.
            bitrate: ``bitrate`` property handed to ``x264enc``.
            mount: Mount path the factory is registered under.
        """
        # Resolve defaults lazily so the caps always follow the module
        # constants instead of duplicating their values in the launch string.
        self.width = int(STREAM_WIDTH if width is None else width)
        self.height = int(STREAM_HEIGHT if height is None else height)
        self.fps = int(FPS if fps is None else fps)
        self.port = port
        self.mount = mount

        Gst.init(None)
        self.server = GstRtspServer.RTSPServer()
        self.server.set_service(str(port))

        self.factory = GstRtspServer.RTSPMediaFactory()
        self.factory.set_launch(
            '( appsrc name=source is-live=true block=true format=GST_FORMAT_TIME '
            f'caps=video/x-raw,format=BGR,width={self.width},'
            f'height={self.height},framerate={self.fps}/1 ! '
            f'videoconvert ! x264enc tune=zerolatency bitrate={bitrate} '
            'speed-preset=ultrafast ! '
            'rtph264pay name=pay0 pt=96 )'
        )
        # One pipeline instance is shared by every connected client.
        self.factory.set_shared(True)

        self.server.get_mount_points().add_factory(mount, self.factory)
        self.server.attach(None)
        print(f"RTSP server started at rtsp://pickle:{port}{mount}")
|
|
|
|
|
|
def detect_ball(frame, model, conf=0.3, class_id=None):
    """Run YOLO detection on a frame and return the boxes as plain tuples.

    Args:
        frame: Image passed unchanged to ``model``.
        model: Callable detector; invoked as
            ``model(frame, verbose=False, classes=[...], conf=...)`` and
            expected to return an iterable of results exposing ``.boxes``.
        conf: Minimum confidence forwarded to the model (default 0.3).
        class_id: Class index to keep. ``None`` (default) uses the
            module-level ``BALL_CLASS_ID``.

    Returns:
        A list of ``(x1, y1, x2, y2, conf)`` tuples, where the four
        coordinates are truncated to ``int`` and ``conf`` is a ``float``.
        Presumably the coordinates are the pixel corners of the box in
        Ultralytics ``xyxy`` order — confirm against the model in use.
    """
    wanted_class = BALL_CLASS_ID if class_id is None else class_id
    results = model(frame, verbose=False, classes=[wanted_class], conf=conf)

    detections = []
    for result in results:
        # Results without a box container would otherwise raise TypeError
        # when iterated.
        if result.boxes is None:
            continue
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            detections.append((x1, y1, x2, y2, float(box.conf[0])))

    return detections
|
|
|
|
|
|
def draw_detections(frame, detections):
    """Draw bounding boxes and confidence labels on ``frame``.

    Args:
        frame: Image handed to ``cv2.rectangle`` / ``cv2.putText``; it is
            drawn on directly and also returned.
        detections: Iterable of ``(x1, y1, x2, y2, conf)`` tuples.

    Returns:
        The same ``frame`` object that was passed in.
    """
    green = (0, 255, 0)  # box and label colour
    min_baseline = 15    # smallest text baseline used for the label

    for x1, y1, x2, y2, conf in detections:
        cv2.rectangle(frame, (x1, y1), (x2, y2), green, 2)

        label = f"Ball {conf:.2f}"
        # Prefer the label just above the box. For boxes touching the top
        # edge that position is clipped or off-frame, so fall back to just
        # below the box's top edge.
        label_y = y1 - 10
        if label_y < min_baseline:
            label_y = max(y1, 0) + 20
        cv2.putText(frame, label, (max(x1, 0), label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, green, 2)

    return frame
|
|
|
|
|
|
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='Pickleball Detection Stream')
    parser.add_argument('--source', type=str, default='0',
                        help='Video source: 0 for camera, or path to video file')
    parser.add_argument('--rtsp-port', type=int, default=8554,
                        help='RTSP server port')
    parser.add_argument('--model', type=str, default='yolov8n.pt',
                        help='YOLO model to use')
    parser.add_argument('--display', action='store_true',
                        help='Show local display window')
    parser.add_argument('--save', type=str, default=None,
                        help='Save output to video file')
    return parser.parse_args()


def _open_capture(source):
    """Open ``source`` and return ``(capture, loop_on_eof)``.

    A purely numeric ``source`` is opened as a camera index and asked for the
    stream resolution and frame rate; anything else is opened as a path. If
    that fails, a GStreamer ``nvarguscamerasrc`` (CSI camera) pipeline is
    tried. ``loop_on_eof`` is True only when a non-numeric source opened
    directly, i.e. when rewinding on end-of-stream makes sense. The returned
    capture may still be unopened; the caller must check ``isOpened()``.
    """
    is_camera = source.isdigit()
    if is_camera:
        cap = cv2.VideoCapture(int(source))
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, STREAM_WIDTH)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, STREAM_HEIGHT)
        cap.set(cv2.CAP_PROP_FPS, FPS)
    else:
        cap = cv2.VideoCapture(source)

    if cap.isOpened():
        return cap, not is_camera

    # Try GStreamer pipeline for CSI camera
    cap.release()
    print("Trying CSI camera via GStreamer...")
    pipeline = (
        "nvarguscamerasrc ! "
        f"video/x-raw(memory:NVMM),width={STREAM_WIDTH},height={STREAM_HEIGHT},"
        f"framerate={FPS}/1 ! "
        "nvvidconv ! video/x-raw,format=BGRx ! "
        "videoconvert ! video/x-raw,format=BGR ! appsink drop=1"
    )
    # A live camera must never be "rewound", so looping is disabled here.
    return cv2.VideoCapture(pipeline, cv2.CAP_GSTREAMER), False


def main():
    """Run the detection loop: read frames, detect, annotate, save/display.

    Frames are resized to ``STREAM_WIDTH`` x ``STREAM_HEIGHT``, passed through
    ``detect_ball`` and ``draw_detections``, stamped with a running FPS value,
    and optionally written to ``--save`` and/or shown in a window
    (``--display``; press ``q`` to quit). Ctrl+C stops the loop; a summary is
    printed and all OpenCV resources are released on the way out.

    NOTE(review): ``--rtsp-port`` is parsed but this function never creates
    an ``RTSPServer`` nor pushes frames to one, so no RTSP stream is produced
    by this code path.
    """
    args = _parse_args()

    print(f"Loading YOLO model: {args.model}")
    model = YOLO(args.model)

    # Try to use CUDA. Catch Exception only, so Ctrl+C / SystemExit during
    # the device transfer are not mistaken for "CUDA not available".
    try:
        model.to("cuda")
        print("Using CUDA for inference")
    except Exception as exc:
        print("CUDA not available, using CPU")
        print(f"  reason: {exc}")

    # Open video source
    print(f"Opening video source: {args.source}")
    cap, loop_on_eof = _open_capture(args.source)
    if not cap.isOpened():
        print("ERROR: Cannot open video source!")
        return

    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or FPS
    print(f"Video: {width}x{height} @ {fps}fps")

    out = None
    frame_count = 0
    total_detections = 0
    frames_since_rewind = 0
    # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
    start_time = time.monotonic()

    try:
        # Setup video writer if saving. Every frame is resized to the stream
        # size before being written, so the writer must be created with that
        # size and not the source's native size.
        if args.save:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(args.save, fourcc, fps,
                                  (STREAM_WIDTH, STREAM_HEIGHT))
            if out.isOpened():
                print(f"Saving output to: {args.save}")
            else:
                print(f"WARNING: Cannot open {args.save} for writing; not saving")
                out.release()
                out = None

        print("Starting detection loop... Press Ctrl+C to stop")

        while True:
            ret, frame = cap.read()
            if not ret:
                # Video file ended: loop, but only if at least one frame was
                # read since the last rewind. Otherwise an unreadable file
                # would spin here forever.
                if loop_on_eof and frames_since_rewind > 0:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    frames_since_rewind = 0
                    continue
                print("Failed to grab frame")
                break
            frames_since_rewind += 1

            # Resize if needed
            if frame.shape[1] != STREAM_WIDTH or frame.shape[0] != STREAM_HEIGHT:
                frame = cv2.resize(frame, (STREAM_WIDTH, STREAM_HEIGHT))

            # Run detection and draw the results
            detections = detect_ball(frame, model)
            total_detections += len(detections)
            frame = draw_detections(frame, detections)

            # Running average FPS since start; guard against a zero interval.
            frame_count += 1
            elapsed = max(time.monotonic() - start_time, 1e-6)
            current_fps = frame_count / elapsed
            if frame_count % 30 == 0:
                print(f"FPS: {current_fps:.1f}, Frame: {frame_count}, "
                      f"Detections this frame: {len(detections)}")

            # Add FPS to frame
            cv2.putText(frame, f"FPS: {current_fps:.1f}",
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

            # Save if requested
            if out is not None:
                out.write(frame)

            # Display if requested
            if args.display:
                cv2.imshow("Pickleball Detection", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

    except KeyboardInterrupt:
        print("\nStopping...")

    finally:
        elapsed = max(time.monotonic() - start_time, 1e-6)
        print(f"\nProcessed {frame_count} frames in {elapsed:.1f}s")
        print(f"Average FPS: {frame_count / elapsed:.1f}")
        print(f"Total ball detections: {total_detections}")

        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()
|
|
|
|
|
|
# Script entry point: run the detection loop only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|