Files
pickle_vision/jetson/ball_detection_stream.py
Ruslan Bakiev 549fd1da9d Initial commit
2026-03-06 09:43:52 +07:00

203 lines
6.5 KiB
Python

#!/usr/bin/env python3
"""
Pickleball detection using YOLOv8 on Jetson.
Works with video file or camera input.
Outputs RTSP stream with bounding boxes around detected balls.
"""
import cv2
import time
import argparse
import gi
gi.require_version('Gst', '1.0')
gi.require_version('GstRtspServer', '1.0')
from gi.repository import Gst, GstRtspServer, GLib
from ultralytics import YOLO
import threading
# COCO class 32 = sports ball.
# Passed to the model's `classes=` filter in detect_ball().
BALL_CLASS_ID = 32
# Stream settings.
# main() resizes every frame to STREAM_WIDTH x STREAM_HEIGHT before detection,
# and requests these values (plus FPS) from a numbered camera source.
STREAM_WIDTH = 1280
STREAM_HEIGHT = 720
FPS = 30
class RTSPServer:
    """Simple RTSP server using GStreamer.

    Creates a GstRtspServer instance with one shared media factory mounted at
    ``/live``. The factory pipeline is an ``appsrc`` named ``source`` that
    accepts raw BGR video, converts it, encodes it with x264 and payloads it
    as RTP/H.264.

    The appsrc caps are built from the module-level STREAM_WIDTH,
    STREAM_HEIGHT and FPS constants so they stay in step with the frame size
    the detection loop produces.

    NOTE(review): this class only attaches the server to the default main
    context; it does not run a GLib main loop and contains no code that
    pushes buffers into the ``source`` appsrc. Presumably the caller is
    expected to do both -- confirm.

    Attributes:
        server: The underlying ``GstRtspServer.RTSPServer``.
        factory: The ``GstRtspServer.RTSPMediaFactory`` mounted at ``/live``.
    """

    def __init__(self, port=8554):
        """Initialise GStreamer and register the ``/live`` mount point.

        Args:
            port: TCP port (service) the RTSP server listens on.
        """
        Gst.init(None)
        self.server = GstRtspServer.RTSPServer()
        self.server.set_service(str(port))
        self.factory = GstRtspServer.RTSPMediaFactory()
        # Build the caps from the shared stream constants instead of repeating
        # the literal 1280x720@30 values, so a change to the constants cannot
        # silently desynchronise the caps from the frames being produced.
        caps = (
            f'video/x-raw,format=BGR,width={STREAM_WIDTH},'
            f'height={STREAM_HEIGHT},framerate={FPS}/1'
        )
        self.factory.set_launch(
            '( appsrc name=source is-live=true block=true format=GST_FORMAT_TIME '
            f'caps={caps} ! '
            'videoconvert ! x264enc tune=zerolatency bitrate=2000 speed-preset=ultrafast ! '
            'rtph264pay name=pay0 pt=96 )'
        )
        # One pipeline instance is shared between all connected clients.
        self.factory.set_shared(True)
        self.server.get_mount_points().add_factory("/live", self.factory)
        # None selects the default GLib main context.
        self.server.attach(None)
        print(f"RTSP server started at rtsp://pickle:{port}/live")
def detect_ball(frame, model, conf_threshold=0.3):
    """Run YOLO detection on frame.

    Args:
        frame: Image passed straight through to ``model``.
        model: Callable invoked as
            ``model(frame, verbose=False, classes=[...], conf=...)`` that
            returns an iterable of results, each exposing ``.boxes``; every
            box exposes ``.xyxy[0]`` (four corner values) and ``.conf[0]``.
        conf_threshold: Minimum confidence forwarded to the model as
            ``conf``. Defaults to 0.3, the value previously hard-coded.

    Returns:
        A list of ``(x1, y1, x2, y2, confidence)`` tuples, one per detected
        box, with the corner values converted to ``int`` and the confidence
        converted to ``float``. Empty when nothing is detected.
    """
    # Restrict the model to the ball class so other classes are not returned.
    results = model(
        frame,
        verbose=False,
        classes=[BALL_CLASS_ID],
        conf=conf_threshold,
    )
    detections = []
    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            score = float(box.conf[0])
            detections.append((x1, y1, x2, y2, score))
    return detections
def draw_detections(frame, detections):
    """Draw bounding boxes on frame.

    Each detection gets a green rectangle and a ``"Ball <conf>"`` label. The
    label normally sits 10 px above the box; when the box is so close to the
    top of the frame that the text would be cut off, the label is drawn just
    inside the top of the box instead.

    Args:
        frame: Image to draw on; it is modified in place.
        detections: Iterable of ``(x1, y1, x2, y2, conf)`` tuples as produced
            by ``detect_ball``.

    Returns:
        The same ``frame`` object that was passed in.
    """
    green = (0, 255, 0)
    # Vertical room reserved for one line of label text at font scale 0.6
    # (approximate; chosen to comfortably fit the glyph height).
    label_room = 20
    for x1, y1, x2, y2, conf in detections:
        # Green box for ball
        cv2.rectangle(frame, (x1, y1), (x2, y2), green, 2)
        label = f"Ball {conf:.2f}"
        # putText anchors the text at its bottom-left corner, so a baseline
        # too close to row 0 would push the label off the top of the frame.
        label_y = y1 - 10
        if label_y < label_room:
            label_y = max(y1, 0) + label_room
        cv2.putText(frame, label, (x1, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, green, 2)
    return frame
def _parse_args():
    """Parse the command-line options for the detection script."""
    parser = argparse.ArgumentParser(description='Pickleball Detection Stream')
    parser.add_argument('--source', type=str, default='0',
                        help='Video source: 0 for camera, or path to video file')
    parser.add_argument('--rtsp-port', type=int, default=8554,
                        help='RTSP server port')
    parser.add_argument('--model', type=str, default='yolov8n.pt',
                        help='YOLO model to use')
    parser.add_argument('--display', action='store_true',
                        help='Show local display window')
    parser.add_argument('--save', type=str, default=None,
                        help='Save output to video file')
    return parser.parse_args()


def _open_capture(source):
    """Open ``source`` with OpenCV, falling back to a CSI camera pipeline.

    An all-digit ``source`` is treated as a camera index and the stream
    width, height and FPS are requested from it; anything else is opened as
    given (e.g. a video file path). If that capture does not open, a GStreamer
    ``nvarguscamerasrc`` pipeline is tried instead.

    Returns:
        The ``cv2.VideoCapture`` object; the caller must check ``isOpened()``.
    """
    if source.isdigit():
        # Camera
        cap = cv2.VideoCapture(int(source))
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, STREAM_WIDTH)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, STREAM_HEIGHT)
        cap.set(cv2.CAP_PROP_FPS, FPS)
    else:
        # Video file
        cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        # Try GStreamer pipeline for CSI camera
        print("Trying CSI camera via GStreamer...")
        cap.release()
        cap = cv2.VideoCapture(
            "nvarguscamerasrc ! "
            "video/x-raw(memory:NVMM),width=1280,height=720,framerate=30/1 ! "
            "nvvidconv ! video/x-raw,format=BGRx ! "
            "videoconvert ! video/x-raw,format=BGR ! appsink drop=1",
            cv2.CAP_GSTREAMER
        )
    return cap


def main():
    """Run the detection loop: read frames, detect balls, annotate, output.

    Loads the YOLO model, opens the requested video source, then for each
    frame resizes it to the stream size, runs ``detect_ball``, draws the
    results and an FPS counter, and optionally writes the frame to a video
    file (``--save``) and/or shows it in a window (``--display``). A video
    file source is rewound and replayed when it ends. The loop stops on
    Ctrl+C, on ``q`` in the display window, or when no frame can be read.

    NOTE(review): ``--rtsp-port`` is parsed but this function never creates
    an ``RTSPServer`` or pushes frames to one, so no RTSP stream is produced
    from here.
    """
    args = _parse_args()
    print(f"Loading YOLO model: {args.model}")
    model = YOLO(args.model)
    # Try to use CUDA
    try:
        model.to("cuda")
        print("Using CUDA for inference")
    except Exception:
        # Narrowed from a bare except so Ctrl+C / SystemExit still propagate.
        print("CUDA not available, using CPU")
    # Open video source
    print(f"Opening video source: {args.source}")
    cap = _open_capture(args.source)
    if not cap.isOpened():
        print("ERROR: Cannot open video source!")
        return
    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30
    print(f"Video: {width}x{height} @ {fps}fps")
    # Setup video writer if saving
    out = None
    if args.save:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # Frames are resized to the stream size before being written, so the
        # writer must be opened with that size rather than the source size;
        # a mismatch would leave the output file without usable frames.
        out = cv2.VideoWriter(args.save, fourcc, fps,
                              (STREAM_WIDTH, STREAM_HEIGHT))
        if out.isOpened():
            print(f"Saving output to: {args.save}")
        else:
            print(f"WARNING: Cannot open video writer for: {args.save}")
            out.release()
            out = None
    is_camera = args.source.isdigit()
    # True right after a rewind; a second consecutive read failure then means
    # the source yields no frames at all, so stop instead of spinning forever.
    just_rewound = False
    frame_count = 0
    start_time = time.time()
    total_detections = 0
    print("Starting detection loop... Press Ctrl+C to stop")
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                if not is_camera and not just_rewound:
                    # Video file ended, loop
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    just_rewound = True
                    continue
                print("Failed to grab frame")
                break
            just_rewound = False
            # Resize if needed
            if frame.shape[1] != STREAM_WIDTH or frame.shape[0] != STREAM_HEIGHT:
                frame = cv2.resize(frame, (STREAM_WIDTH, STREAM_HEIGHT))
            # Run detection
            detections = detect_ball(frame, model)
            total_detections += len(detections)
            # Draw detections
            frame = draw_detections(frame, detections)
            # Running average FPS since the loop started
            frame_count += 1
            elapsed = time.time() - start_time
            current_fps = frame_count / elapsed if elapsed > 0 else 0.0
            if frame_count % 30 == 0:
                print(f"FPS: {current_fps:.1f}, Frame: {frame_count}, "
                      f"Detections this frame: {len(detections)}")
            # Add FPS to frame
            cv2.putText(frame, f"FPS: {current_fps:.1f}",
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
            # Save if requested
            if out is not None:
                out.write(frame)
            # Display if requested
            if args.display:
                cv2.imshow("Pickleball Detection", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        elapsed = time.time() - start_time
        # Guard the division: elapsed can be 0 on a coarse clock.
        average_fps = frame_count / elapsed if elapsed > 0 else 0.0
        print(f"\nProcessed {frame_count} frames in {elapsed:.1f}s")
        print(f"Average FPS: {average_fps:.1f}")
        print(f"Total ball detections: {total_detections}")
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()
# Run the detection loop only when executed as a script, not when imported.
if __name__ == "__main__":
    main()