Initial commit
This commit is contained in:
202
jetson/ball_detection_stream.py
Normal file
202
jetson/ball_detection_stream.py
Normal file
@@ -0,0 +1,202 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pickleball detection using YOLOv8 on Jetson.
|
||||
Works with video file or camera input.
|
||||
Outputs RTSP stream with bounding boxes around detected balls.
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import time
|
||||
import argparse
|
||||
import gi
|
||||
gi.require_version('Gst', '1.0')
|
||||
gi.require_version('GstRtspServer', '1.0')
|
||||
from gi.repository import Gst, GstRtspServer, GLib
|
||||
from ultralytics import YOLO
|
||||
import threading
|
||||
|
||||
# COCO class 32 = sports ball
|
||||
BALL_CLASS_ID = 32
|
||||
|
||||
# Stream settings
|
||||
STREAM_WIDTH = 1280
|
||||
STREAM_HEIGHT = 720
|
||||
FPS = 30
|
||||
|
||||
|
||||
class RTSPServer:
    """Simple RTSP server using GStreamer.

    Serves an H.264 stream at rtsp://<host>:<port>/live from an appsrc
    element named "source". Note: frames must be pushed into that appsrc
    by the caller; this class only sets up the pipeline.
    """

    def __init__(self, port=8554, width=STREAM_WIDTH, height=STREAM_HEIGHT,
                 fps=FPS, bitrate=2000):
        """Create and attach the RTSP server.

        Args:
            port: TCP port to serve RTSP on.
            width: Frame width the appsrc caps advertise (was hard-coded
                to 1280; now defaults to STREAM_WIDTH so the pipeline
                cannot drift from the module constants).
            height: Frame height the appsrc caps advertise.
            fps: Framerate the appsrc caps advertise.
            bitrate: x264enc target bitrate in kbit/s.
        """
        Gst.init(None)
        self.server = GstRtspServer.RTSPServer()
        self.server.set_service(str(port))
        self.factory = GstRtspServer.RTSPMediaFactory()
        # appsrc expects raw BGR frames; x264enc is tuned for low latency.
        self.factory.set_launch(
            '( appsrc name=source is-live=true block=true format=GST_FORMAT_TIME '
            f'caps=video/x-raw,format=BGR,width={width},height={height},framerate={fps}/1 ! '
            f'videoconvert ! x264enc tune=zerolatency bitrate={bitrate} speed-preset=ultrafast ! '
            'rtph264pay name=pay0 pt=96 )'
        )
        # Shared: all clients get the same media pipeline.
        self.factory.set_shared(True)
        self.server.get_mount_points().add_factory("/live", self.factory)
        self.server.attach(None)
        # NOTE(review): "pickle" is presumably the Jetson's hostname -- confirm.
        print(f"RTSP server started at rtsp://pickle:{port}/live")
|
||||
|
||||
|
||||
def detect_ball(frame, model, conf=0.3):
    """Run YOLO sports-ball detection on a single frame.

    Args:
        frame: BGR image (numpy array) to run inference on.
        model: Ultralytics YOLO model instance (called directly).
        conf: Minimum confidence threshold (was hard-coded to 0.3;
            now a parameter with the same default).

    Returns:
        List of (x1, y1, x2, y2, confidence) tuples; coordinates are
        int pixels, confidence is a float.
    """
    # Restrict inference to the COCO "sports ball" class.
    results = model(frame, verbose=False, classes=[BALL_CLASS_ID], conf=conf)

    detections = []
    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            detections.append((x1, y1, x2, y2, float(box.conf[0])))

    return detections
|
||||
|
||||
|
||||
def draw_detections(frame, detections):
    """Overlay a labeled green bounding box for every detection.

    Args:
        frame: BGR image to draw on (modified in place).
        detections: iterable of (x1, y1, x2, y2, conf) tuples.

    Returns:
        The same frame, with boxes and confidence labels drawn.
    """
    box_color = (0, 255, 0)  # green in BGR
    for x1, y1, x2, y2, conf in detections:
        cv2.rectangle(frame, (x1, y1), (x2, y2), box_color, 2)
        # Confidence label just above the top-left corner of the box.
        cv2.putText(
            frame,
            f"Ball {conf:.2f}",
            (x1, y1 - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            box_color,
            2,
        )
    return frame
|
||||
|
||||
|
||||
def _open_capture(source):
    """Open a cv2.VideoCapture for a camera index, a file path, or a
    Jetson CSI camera (GStreamer fallback).

    Args:
        source: '0'-style digit string for a camera index, otherwise a
            video file path.

    Returns:
        An opened cv2.VideoCapture, or None if every backend failed.
    """
    if source.isdigit():
        # USB/V4L2 camera by index; request the target stream geometry.
        cap = cv2.VideoCapture(int(source))
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, STREAM_WIDTH)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, STREAM_HEIGHT)
        cap.set(cv2.CAP_PROP_FPS, FPS)
    else:
        # Video file.
        cap = cv2.VideoCapture(source)

    if not cap.isOpened():
        # Jetson CSI camera fallback via an nvargus GStreamer pipeline.
        print("Trying CSI camera via GStreamer...")
        cap = cv2.VideoCapture(
            "nvarguscamerasrc ! "
            "video/x-raw(memory:NVMM),width=1280,height=720,framerate=30/1 ! "
            "nvvidconv ! video/x-raw,format=BGRx ! "
            "videoconvert ! video/x-raw,format=BGR ! appsink drop=1",
            cv2.CAP_GSTREAMER
        )

    return cap if cap.isOpened() else None


def main():
    """Parse CLI args, open the video source, and run the detection loop.

    Reads frames, runs YOLO ball detection, draws boxes, and optionally
    saves and/or displays the annotated frames until end-of-stream,
    'q' in the display window, or Ctrl+C.
    """
    parser = argparse.ArgumentParser(description='Pickleball Detection Stream')
    parser.add_argument('--source', type=str, default='0',
                        help='Video source: 0 for camera, or path to video file')
    parser.add_argument('--rtsp-port', type=int, default=8554,
                        help='RTSP server port')
    parser.add_argument('--model', type=str, default='yolov8n.pt',
                        help='YOLO model to use')
    parser.add_argument('--display', action='store_true',
                        help='Show local display window')
    parser.add_argument('--save', type=str, default=None,
                        help='Save output to video file')
    args = parser.parse_args()

    # NOTE(review): --rtsp-port is accepted but RTSPServer is never
    # instantiated here, so no RTSP stream is actually served despite the
    # module docstring -- confirm whether that wiring was dropped.

    print(f"Loading YOLO model: {args.model}")
    model = YOLO(args.model)

    # Try to use CUDA; fall back to CPU on any failure (no GPU, CPU-only
    # torch build, ...). Was a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit.
    try:
        model.to("cuda")
        print("Using CUDA for inference")
    except Exception:
        print("CUDA not available, using CPU")

    # Open video source (camera index, file, or CSI camera fallback).
    print(f"Opening video source: {args.source}")
    cap = _open_capture(args.source)
    if cap is None:
        print("ERROR: Cannot open video source!")
        return

    # Get video properties (FPS may be reported as 0 by some backends).
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 30
    print(f"Video: {width}x{height} @ {fps}fps")

    # Setup video writer if saving.
    out = None
    if args.save:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(args.save, fourcc, fps, (width, height))
        print(f"Saving output to: {args.save}")

    frame_count = 0
    start_time = time.time()
    total_detections = 0

    print("Starting detection loop... Press Ctrl+C to stop")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                if not args.source.isdigit():
                    # Video file ended: rewind and loop forever.
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    continue
                print("Failed to grab frame")
                break

            # Normalize to the stream geometry if the source differs.
            if frame.shape[1] != STREAM_WIDTH or frame.shape[0] != STREAM_HEIGHT:
                frame = cv2.resize(frame, (STREAM_WIDTH, STREAM_HEIGHT))

            # Run detection and annotate the frame.
            detections = detect_ball(frame, model)
            total_detections += len(detections)
            frame = draw_detections(frame, detections)

            # Periodic console FPS report (every 30 frames).
            frame_count += 1
            if frame_count % 30 == 0:
                elapsed = time.time() - start_time
                current_fps = frame_count / elapsed
                print(f"FPS: {current_fps:.1f}, Frame: {frame_count}, "
                      f"Detections this frame: {len(detections)}")

            # On-frame FPS overlay.
            cv2.putText(frame, f"FPS: {frame_count / (time.time() - start_time):.1f}",
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

            # Save if requested.
            if out:
                out.write(frame)

            # Display if requested; 'q' quits.
            if args.display:
                cv2.imshow("Pickleball Detection", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

    except KeyboardInterrupt:
        print("\nStopping...")

    finally:
        elapsed = time.time() - start_time
        print(f"\nProcessed {frame_count} frames in {elapsed:.1f}s")
        # Guard against division by zero if we stopped before any work.
        if elapsed > 0:
            print(f"Average FPS: {frame_count / elapsed:.1f}")
        print(f"Total ball detections: {total_detections}")

        cap.release()
        if out:
            out.release()
        cv2.destroyAllWindows()
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user