Initial commit
This commit is contained in:
222
jetson/rtsp_detection_server.py
Normal file
222
jetson/rtsp_detection_server.py
Normal file
@@ -0,0 +1,222 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
RTSP server with YOLOv8 ball detection for Jetson.
|
||||
Streams video with detections over RTSP.
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import time
|
||||
import argparse
|
||||
import threading
|
||||
import gi
|
||||
gi.require_version('Gst', '1.0')
|
||||
gi.require_version('GstRtspServer', '1.0')
|
||||
from gi.repository import Gst, GstRtspServer, GLib
|
||||
from ultralytics import YOLO
|
||||
import numpy as np
|
||||
|
||||
# COCO class 32 = sports ball
|
||||
BALL_CLASS_ID = 32
|
||||
|
||||
|
||||
class DetectionRTSPServer:
|
||||
"""RTSP server that streams video with YOLO detections."""
|
||||
|
||||
def __init__(self, source, model_path='yolov8n.pt', port=8554, width=1280, height=720, fps=30):
|
||||
self.source = source
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.fps = fps
|
||||
self.port = port
|
||||
self.running = False
|
||||
self.frame = None
|
||||
self.lock = threading.Lock()
|
||||
|
||||
# Load YOLO model
|
||||
print(f"Loading YOLO model: {model_path}")
|
||||
self.model = YOLO(model_path)
|
||||
try:
|
||||
self.model.to("cuda")
|
||||
print("Using CUDA")
|
||||
except:
|
||||
print("Using CPU")
|
||||
|
||||
# Init GStreamer
|
||||
Gst.init(None)
|
||||
|
||||
def detect_and_draw(self, frame):
|
||||
"""Run detection and draw boxes."""
|
||||
results = self.model(frame, verbose=False, classes=[BALL_CLASS_ID], conf=0.25)
|
||||
|
||||
for result in results:
|
||||
for box in result.boxes:
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
||||
conf = float(box.conf[0])
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 3)
|
||||
cv2.putText(frame, f"Ball {conf:.2f}", (x1, y1 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
|
||||
|
||||
return frame
|
||||
|
||||
def capture_loop(self):
|
||||
"""Capture frames and run detection."""
|
||||
print(f"Opening source: {self.source}")
|
||||
|
||||
if self.source.isdigit():
|
||||
cap = cv2.VideoCapture(int(self.source))
|
||||
elif self.source == 'csi':
|
||||
# CSI camera on Jetson
|
||||
cap = cv2.VideoCapture(
|
||||
f"nvarguscamerasrc ! video/x-raw(memory:NVMM),width={self.width},height={self.height},"
|
||||
f"framerate={self.fps}/1 ! nvvidconv ! video/x-raw,format=BGRx ! "
|
||||
f"videoconvert ! video/x-raw,format=BGR ! appsink drop=1",
|
||||
cv2.CAP_GSTREAMER
|
||||
)
|
||||
else:
|
||||
cap = cv2.VideoCapture(self.source)
|
||||
|
||||
if not cap.isOpened():
|
||||
print("ERROR: Cannot open video source!")
|
||||
return
|
||||
|
||||
frame_count = 0
|
||||
start_time = time.time()
|
||||
|
||||
while self.running:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
if not self.source.isdigit() and self.source != 'csi':
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||
continue
|
||||
break
|
||||
|
||||
# Resize
|
||||
frame = cv2.resize(frame, (self.width, self.height))
|
||||
|
||||
# Detect
|
||||
frame = self.detect_and_draw(frame)
|
||||
|
||||
# FPS overlay
|
||||
frame_count += 1
|
||||
fps = frame_count / (time.time() - start_time)
|
||||
cv2.putText(frame, f"FPS: {fps:.1f}", (10, 30),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
|
||||
|
||||
with self.lock:
|
||||
self.frame = frame.copy()
|
||||
|
||||
if frame_count % 100 == 0:
|
||||
print(f"FPS: {fps:.1f}")
|
||||
|
||||
cap.release()
|
||||
|
||||
def start(self):
|
||||
"""Start RTSP server and capture."""
|
||||
self.running = True
|
||||
|
||||
# Start capture thread
|
||||
self.capture_thread = threading.Thread(target=self.capture_loop)
|
||||
self.capture_thread.start()
|
||||
|
||||
# Wait for first frame
|
||||
print("Waiting for first frame...")
|
||||
while self.frame is None and self.running:
|
||||
time.sleep(0.1)
|
||||
|
||||
# Create RTSP server with test source first, then we'll push frames
|
||||
self.server = GstRtspServer.RTSPServer.new()
|
||||
self.server.set_service(str(self.port))
|
||||
|
||||
# Create factory with appsrc
|
||||
self.factory = GstRtspServer.RTSPMediaFactory.new()
|
||||
|
||||
# Pipeline that accepts raw video and encodes to H264
|
||||
launch_str = (
|
||||
f'( appsrc name=mysrc is-live=true block=false format=GST_FORMAT_TIME '
|
||||
f'caps=video/x-raw,format=BGR,width={self.width},height={self.height},framerate={self.fps}/1 ! '
|
||||
f'queue ! videoconvert ! video/x-raw,format=I420 ! '
|
||||
f'x264enc tune=zerolatency bitrate=4000 speed-preset=ultrafast ! '
|
||||
f'rtph264pay config-interval=1 name=pay0 pt=96 )'
|
||||
)
|
||||
|
||||
self.factory.set_launch(launch_str)
|
||||
self.factory.set_shared(True)
|
||||
self.factory.connect('media-configure', self.on_media_configure)
|
||||
|
||||
mounts = self.server.get_mount_points()
|
||||
mounts.add_factory('/live', self.factory)
|
||||
|
||||
self.server.attach(None)
|
||||
print(f"\n{'='*50}")
|
||||
print(f"RTSP stream ready at: rtsp://pickle:{self.port}/live")
|
||||
print(f"{'='*50}\n")
|
||||
|
||||
# Run GLib main loop
|
||||
self.loop = GLib.MainLoop()
|
||||
try:
|
||||
self.loop.run()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
self.stop()
|
||||
|
||||
def on_media_configure(self, factory, media):
|
||||
"""Configure media when client connects."""
|
||||
print("Client connected!")
|
||||
appsrc = media.get_element().get_child_by_name('mysrc')
|
||||
appsrc.connect('need-data', self.on_need_data)
|
||||
|
||||
def on_need_data(self, src, length):
|
||||
"""Push frame to appsrc when needed."""
|
||||
with self.lock:
|
||||
if self.frame is None:
|
||||
return
|
||||
|
||||
frame = self.frame.copy()
|
||||
|
||||
# Create buffer
|
||||
data = frame.tobytes()
|
||||
buf = Gst.Buffer.new_allocate(None, len(data), None)
|
||||
buf.fill(0, data)
|
||||
|
||||
# Set timestamp
|
||||
timestamp = int(time.time() * Gst.SECOND)
|
||||
buf.pts = timestamp
|
||||
buf.duration = int(Gst.SECOND / self.fps)
|
||||
|
||||
src.emit('push-buffer', buf)
|
||||
|
||||
def stop(self):
|
||||
"""Stop server."""
|
||||
self.running = False
|
||||
if hasattr(self, 'capture_thread'):
|
||||
self.capture_thread.join()
|
||||
print("Server stopped")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='RTSP Detection Server')
|
||||
parser.add_argument('--source', type=str, default='csi',
|
||||
help='Video source: csi, 0 (USB cam), or video file path')
|
||||
parser.add_argument('--model', type=str, default='yolov8n.pt',
|
||||
help='YOLO model')
|
||||
parser.add_argument('--port', type=int, default=8554,
|
||||
help='RTSP port')
|
||||
parser.add_argument('--width', type=int, default=1280)
|
||||
parser.add_argument('--height', type=int, default=720)
|
||||
parser.add_argument('--fps', type=int, default=30)
|
||||
args = parser.parse_args()
|
||||
|
||||
server = DetectionRTSPServer(
|
||||
source=args.source,
|
||||
model_path=args.model,
|
||||
port=args.port,
|
||||
width=args.width,
|
||||
height=args.height,
|
||||
fps=args.fps
|
||||
)
|
||||
server.start()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user