#!/usr/bin/env python3
|
|
"""
|
|
RTSP server with YOLOv8 ball detection for Jetson.
|
|
Streams video with detections over RTSP.
|
|
"""
|
|
|
|
import cv2
|
|
import time
|
|
import argparse
|
|
import threading
|
|
import gi
|
|
gi.require_version('Gst', '1.0')
|
|
gi.require_version('GstRtspServer', '1.0')
|
|
from gi.repository import Gst, GstRtspServer, GLib
|
|
from ultralytics import YOLO
|
|
import numpy as np
|
|
|
|
# COCO class 32 = sports ball.
# Passed as the only entry of the `classes` filter when the model is run,
# so detections of every other class are not requested.
BALL_CLASS_ID = 32
|
|
|
|
|
|
class DetectionRTSPServer:
    """RTSP server that streams video with YOLO detections.

    A background thread (``capture_loop``) reads frames from the source,
    draws the detections on them and stores the newest annotated frame in
    ``self.frame``. The ``appsrc`` element of the RTSP pipeline fetches that
    frame through ``on_need_data`` whenever it asks for more data.
    """

    def __init__(self, source, model_path='yolov8n.pt', port=8554, width=1280, height=720, fps=30):
        """Store the stream settings, load the model and initialise GStreamer.

        Args:
            source: ``'csi'`` for the Jetson CSI camera, a digit string for a
                camera index, or anything else to hand to
                ``cv2.VideoCapture`` as-is (e.g. a video file path).
            model_path: Weights name/path passed to ``YOLO``.
            port: RTSP service port.
            width: Width frames are resized to and advertised in the caps.
            height: Height frames are resized to and advertised in the caps.
            fps: Nominal frame rate used for the caps and buffer timing.

        Raises:
            ValueError: If ``width``, ``height`` or ``fps`` is not positive.
        """
        # Validate before any expensive work: fps is later used as a divisor
        # and width/height are used as a resize target.
        if width <= 0 or height <= 0:
            raise ValueError(
                f"width and height must be positive, got {width!r}x{height!r}"
            )
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps!r}")

        self.source = source
        self.width = width
        self.height = height
        self.fps = fps
        self.port = port
        self.running = False
        self.frame = None
        self.lock = threading.Lock()
        # Number of buffers pushed to the current media; drives buffer PTS.
        self._frame_index = 0

        # Load YOLO model
        print(f"Loading YOLO model: {model_path}")
        self.model = YOLO(model_path)
        try:
            self.model.to("cuda")
        except Exception as exc:
            # Fall back to CPU, but do not swallow KeyboardInterrupt /
            # SystemExit and do not hide why CUDA could not be used.
            print(f"Using CPU (CUDA unavailable: {exc})")
        else:
            print("Using CUDA")

        # Init GStreamer
        Gst.init(None)

    def detect_and_draw(self, frame):
        """Run detection on ``frame`` and draw a box and label per detection.

        The frame is modified in place and also returned.
        """
        results = self.model(frame, verbose=False, classes=[BALL_CLASS_ID], conf=0.25)

        for result in results:
            for box in result.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                conf = float(box.conf[0])
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 3)
                # Keep the label inside the image when the box touches the
                # top edge (y1 - 10 would otherwise be off-frame).
                label_y = max(y1 - 10, 25)
                cv2.putText(frame, f"Ball {conf:.2f}", (x1, label_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

        return frame

    def capture_loop(self):
        """Capture frames, run detection and publish the annotated frame.

        Runs until ``self.running`` is cleared. For camera sources the loop
        ends on the first failed read; other sources are rewound to frame 0
        and read again. The capture device is always released on exit.
        """
        print(f"Opening source: {self.source}")

        # Accept non-str sources (e.g. an int index or a Path) as well.
        source = str(self.source)
        is_camera_index = source.isdigit()
        is_csi = source == 'csi'

        if is_camera_index:
            cap = cv2.VideoCapture(int(source))
        elif is_csi:
            # CSI camera on Jetson
            cap = cv2.VideoCapture(
                f"nvarguscamerasrc ! video/x-raw(memory:NVMM),width={self.width},height={self.height},"
                f"framerate={self.fps}/1 ! nvvidconv ! video/x-raw,format=BGRx ! "
                f"videoconvert ! video/x-raw,format=BGR ! appsink drop=1",
                cv2.CAP_GSTREAMER
            )
        else:
            cap = cv2.VideoCapture(source)

        # Only sources that are neither a camera index nor the CSI camera
        # are rewound when a read fails.
        rewindable = not (is_camera_index or is_csi)

        try:
            if not cap.isOpened():
                print("ERROR: Cannot open video source!")
                return

            frame_count = 0
            failed_reads = 0
            start_time = time.time()

            while self.running:
                ret, frame = cap.read()
                if not ret:
                    if not rewindable:
                        break
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    failed_reads += 1
                    if failed_reads > 1:
                        # Rewinding did not help; back off instead of
                        # spinning at full CPU while still retrying.
                        time.sleep(0.1)
                    continue
                failed_reads = 0

                # Resize
                frame = cv2.resize(frame, (self.width, self.height))

                # Detect
                frame = self.detect_and_draw(frame)

                # FPS overlay
                frame_count += 1
                elapsed = time.time() - start_time
                fps = frame_count / elapsed if elapsed > 0 else 0.0
                cv2.putText(frame, f"FPS: {fps:.1f}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

                # The published array is never modified afterwards, so
                # readers may use it without making their own copy.
                with self.lock:
                    self.frame = frame.copy()

                if frame_count % 100 == 0:
                    print(f"FPS: {fps:.1f}")
        finally:
            # Release the device even if detection or drawing raised.
            cap.release()

    def start(self):
        """Start the capture thread and serve the stream until interrupted.

        Blocks in the GLib main loop. Returns early, without starting the
        RTSP server, when the capture thread ends before producing a frame
        or when the server cannot be attached. ``stop()`` is always called
        before returning.
        """
        self.running = True

        # Start capture thread
        self.capture_thread = threading.Thread(target=self.capture_loop)
        self.capture_thread.start()

        try:
            # Wait for first frame; also stop waiting when the capture
            # thread has exited, otherwise a source that cannot be opened
            # would block here forever.
            print("Waiting for first frame...")
            while (self.frame is None and self.running
                   and self.capture_thread.is_alive()):
                time.sleep(0.1)

            if self.frame is None:
                print("ERROR: No frame received from source; RTSP server not started.")
                return

            # Create RTSP server
            self.server = GstRtspServer.RTSPServer.new()
            self.server.set_service(str(self.port))

            # Create factory with appsrc
            self.factory = GstRtspServer.RTSPMediaFactory.new()

            # Pipeline that accepts raw video and encodes to H264
            launch_str = (
                f'( appsrc name=mysrc is-live=true block=false format=GST_FORMAT_TIME '
                f'caps=video/x-raw,format=BGR,width={self.width},height={self.height},framerate={self.fps}/1 ! '
                f'queue ! videoconvert ! video/x-raw,format=I420 ! '
                f'x264enc tune=zerolatency bitrate=4000 speed-preset=ultrafast ! '
                f'rtph264pay config-interval=1 name=pay0 pt=96 )'
            )

            self.factory.set_launch(launch_str)
            self.factory.set_shared(True)
            self.factory.connect('media-configure', self.on_media_configure)

            mounts = self.server.get_mount_points()
            mounts.add_factory('/live', self.factory)

            # attach() returns 0 when the server could not be attached
            # (for example when the port is already in use).
            if self.server.attach(None) == 0:
                print(f"ERROR: Cannot attach RTSP server on port {self.port}")
                return

            print(f"\n{'='*50}")
            # NOTE(review): the host name "pickle" is hard-coded in this
            # message; confirm it matches the machine this runs on.
            print(f"RTSP stream ready at: rtsp://pickle:{self.port}/live")
            print(f"{'='*50}\n")

            # Run GLib main loop
            self.loop = GLib.MainLoop()
            self.loop.run()
        except KeyboardInterrupt:
            pass
        finally:
            # Also reached on early return and on unexpected errors, so the
            # non-daemon capture thread never keeps the process alive.
            self.stop()

    def on_media_configure(self, factory, media):
        """Hook ``on_need_data`` up to the appsrc of a newly created media."""
        print("Client connected!")
        appsrc = media.get_element().get_child_by_name('mysrc')
        if appsrc is None:
            print("ERROR: appsrc 'mysrc' not found in media pipeline")
            return

        # A new media pipeline starts at running time 0, so restart the
        # buffer counter the timestamps are derived from.
        with self.lock:
            self._frame_index = 0

        appsrc.connect('need-data', self.on_need_data)

    def on_need_data(self, src, length):
        """Push the most recent frame to ``src`` as one timestamped buffer.

        Does nothing when no frame has been published yet.
        """
        with self.lock:
            frame = self.frame
            if frame is None:
                return
            index = self._frame_index
            self._frame_index += 1

        # Create buffer. tobytes() already copies the pixel data and the
        # published frame is never mutated, so no extra frame copy is needed.
        data = frame.tobytes()
        buf = Gst.Buffer.new_allocate(None, len(data), None)
        buf.fill(0, data)

        # Set timestamp: stream time counted from 0 in whole frame
        # durations, rather than wall-clock time since the epoch.
        duration = int(Gst.SECOND / self.fps)
        buf.pts = index * duration
        buf.duration = duration

        ret = src.emit('push-buffer', buf)
        if ret != Gst.FlowReturn.OK:
            print(f"push-buffer returned {ret}")

    def stop(self):
        """Stop the capture thread and the main loop, if they are running.

        Safe to call before ``start()`` and more than once.
        """
        self.running = False

        loop = getattr(self, 'loop', None)
        if loop is not None and loop.is_running():
            loop.quit()

        thread = getattr(self, 'capture_thread', None)
        # A thread cannot join itself; skip that case and finished threads.
        if (thread is not None and thread is not threading.current_thread()
                and thread.is_alive()):
            thread.join()

        print("Server stopped")
|
|
|
|
|
|
def main():
    """Read the command-line options and run the detection RTSP server."""
    cli = argparse.ArgumentParser(description='RTSP Detection Server')
    cli.add_argument(
        '--source',
        type=str,
        default='csi',
        help='Video source: csi, 0 (USB cam), or video file path',
    )
    cli.add_argument(
        '--model',
        type=str,
        default='yolov8n.pt',
        help='YOLO model',
    )
    cli.add_argument(
        '--port',
        type=int,
        default=8554,
        help='RTSP port',
    )
    # The remaining options are plain integers without help text.
    for flag, default in (('--width', 1280), ('--height', 720), ('--fps', 30)):
        cli.add_argument(flag, type=int, default=default)
    opts = cli.parse_args()

    settings = {
        'source': opts.source,
        'model_path': opts.model,
        'port': opts.port,
        'width': opts.width,
        'height': opts.height,
        'fps': opts.fps,
    }
    DetectionRTSPServer(**settings).start()
|
|
|
|
|
|
# Run the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
|