Initial commit
This commit is contained in:
277
dagster_project/assets/court_detection.py
Normal file
277
dagster_project/assets/court_detection.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Asset 2: Detect court keypoints using Roboflow Hosted API"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from dagster import asset, AssetExecutionContext
|
||||
from inference_sdk import InferenceHTTPClient
|
||||
|
||||
|
||||
@asset(
|
||||
io_manager_key="json_io_manager",
|
||||
compute_kind="roboflow",
|
||||
description="Detect pickleball court corners using Roboflow keypoint detection model"
|
||||
)
|
||||
def detect_court_keypoints(
|
||||
context: AssetExecutionContext,
|
||||
extract_video_frames: Dict
|
||||
) -> Dict:
|
||||
"""
|
||||
Detect court keypoints from first frame using Roboflow model
|
||||
|
||||
Inputs:
|
||||
- extract_video_frames: metadata from frame extraction
|
||||
- data/frames/frame_0000.jpg: first frame
|
||||
|
||||
Outputs:
|
||||
- data/detect_court_keypoints.json
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- corners_pixel: list of 4 corner coordinates [[x,y], ...]
|
||||
- court_width_m: court width in meters (6.1)
|
||||
- court_length_m: court length in meters (13.4)
|
||||
- keypoints: all detected keypoints
|
||||
"""
|
||||
from inference import get_model
|
||||
|
||||
frames_dir = Path(extract_video_frames['frames_dir'])
|
||||
first_frame_path = frames_dir / "frame_0000.jpg"
|
||||
|
||||
context.log.info(f"Loading first frame: {first_frame_path}")
|
||||
|
||||
if not first_frame_path.exists():
|
||||
raise FileNotFoundError(f"First frame not found: {first_frame_path}")
|
||||
|
||||
# Load frame
|
||||
frame = cv2.imread(str(first_frame_path))
|
||||
h, w = frame.shape[:2]
|
||||
context.log.info(f"Frame dimensions: {w}x{h}")
|
||||
|
||||
# Get API key
|
||||
api_key = os.getenv("ROBOFLOW_API_KEY")
|
||||
if not api_key:
|
||||
context.log.warning("ROBOFLOW_API_KEY not set, using estimated corners")
|
||||
corners = _estimate_court_corners(w, h)
|
||||
else:
|
||||
# Try to detect court using Roboflow Hosted API
|
||||
try:
|
||||
context.log.info("Detecting court using Roboflow Hosted API...")
|
||||
|
||||
client = InferenceHTTPClient(
|
||||
api_url="https://serverless.roboflow.com",
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
result = client.infer(str(first_frame_path), model_id="pickleball-court-cfyv4/1")
|
||||
|
||||
# Extract keypoints from result
|
||||
all_points = []
|
||||
if result and 'predictions' in result and len(result['predictions']) > 0:
|
||||
pred = result['predictions'][0]
|
||||
if 'points' in pred and len(pred['points']) >= 4:
|
||||
# Модель возвращает много points (линии корта)
|
||||
all_points = [[p['x'], p['y']] for p in pred['points']]
|
||||
context.log.info(f"✓ Detected {len(all_points)} keypoints from court lines")
|
||||
|
||||
# Находим 4 угла для калибровки (но не для визуализации)
|
||||
corners = _extract_court_corners_from_points(all_points, w, h)
|
||||
context.log.info(f"✓ Extracted 4 corners from keypoints")
|
||||
else:
|
||||
context.log.warning("No keypoints in prediction, using estimated corners")
|
||||
corners = _estimate_court_corners(w, h)
|
||||
else:
|
||||
context.log.warning("No predictions from model, using estimated corners")
|
||||
corners = _estimate_court_corners(w, h)
|
||||
|
||||
except Exception as e:
|
||||
context.log.warning(f"Court detection failed: {e}. Using estimated corners.")
|
||||
corners = _estimate_court_corners(w, h)
|
||||
|
||||
context.log.info(f"Court corners: {corners}")
|
||||
|
||||
# Save visualization - рисуем ВСЕ точки и линии от модели
|
||||
vis_frame = frame.copy()
|
||||
|
||||
# Рисуем все точки от модели
|
||||
if len(all_points) > 0:
|
||||
context.log.info(f"Drawing {len(all_points)} keypoints on visualization")
|
||||
|
||||
# Рисуем все точки
|
||||
for i, point in enumerate(all_points):
|
||||
x, y = int(point[0]), int(point[1])
|
||||
cv2.circle(vis_frame, (x, y), 5, (0, 255, 0), -1)
|
||||
# Подписываем каждую точку номером
|
||||
cv2.putText(
|
||||
vis_frame,
|
||||
str(i),
|
||||
(x + 8, y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.4,
|
||||
(255, 255, 0),
|
||||
1
|
||||
)
|
||||
|
||||
# Соединяем все соседние точки линиями
|
||||
for i in range(len(all_points) - 1):
|
||||
p1 = tuple(map(int, all_points[i]))
|
||||
p2 = tuple(map(int, all_points[i + 1]))
|
||||
cv2.line(vis_frame, p1, p2, (0, 255, 0), 2)
|
||||
|
||||
# Save visualization with run_id
|
||||
run_id = context.run_id
|
||||
vis_path = Path(f"data/{run_id}/court_detection_preview.jpg")
|
||||
cv2.imwrite(str(vis_path), vis_frame)
|
||||
context.log.info(f"Saved court visualization to {vis_path}")
|
||||
|
||||
return {
|
||||
"corners_pixel": corners,
|
||||
"court_width_m": 6.1,
|
||||
"court_length_m": 13.4,
|
||||
"frame_width": w,
|
||||
"frame_height": h
|
||||
}
|
||||
|
||||
|
||||
def _estimate_court_corners(width: int, height: int) -> List[List[float]]:
|
||||
"""
|
||||
Estimate court corners based on typical DJI camera position
|
||||
(camera in corner at angle)
|
||||
|
||||
Returns corners in order: [TL, TR, BR, BL]
|
||||
"""
|
||||
# Assume court takes up ~80% of frame with perspective
|
||||
margin_x = width * 0.05
|
||||
margin_y = height * 0.1
|
||||
|
||||
# Perspective: far edge narrower than near edge
|
||||
return [
|
||||
[margin_x + width * 0.1, margin_y], # Top-left (far)
|
||||
[width - margin_x - width * 0.1, margin_y], # Top-right (far)
|
||||
[width - margin_x, height - margin_y], # Bottom-right (near)
|
||||
[margin_x, height - margin_y] # Bottom-left (near)
|
||||
]
|
||||
|
||||
|
||||
def _extract_court_corners_from_points(points: List[List[float]], width: int, height: int) -> List[List[float]]:
|
||||
"""
|
||||
Extract 4 court corners from many detected points (court lines)
|
||||
|
||||
Strategy:
|
||||
1. Build convex hull from all points
|
||||
2. Classify hull points into 4 sides (left, right, top, bottom)
|
||||
3. Fit line for each side using linear regression
|
||||
4. Find 4 corners as intersections of fitted lines
|
||||
|
||||
This works even if one corner is not visible on frame (extrapolation)
|
||||
"""
|
||||
if len(points) < 4:
|
||||
return _estimate_court_corners(width, height)
|
||||
|
||||
# Build convex hull from all points
|
||||
points_array = np.array(points, dtype=np.float32)
|
||||
hull = cv2.convexHull(points_array)
|
||||
hull_points = np.array([p[0] for p in hull], dtype=np.float32)
|
||||
|
||||
# Classify hull points into 4 sides
|
||||
# Strategy: sort hull points by angle from centroid, then split into 4 groups
|
||||
center = hull_points.mean(axis=0)
|
||||
|
||||
# Calculate angle for each point relative to center
|
||||
angles = np.arctan2(hull_points[:, 1] - center[1], hull_points[:, 0] - center[0])
|
||||
|
||||
# Sort points by angle
|
||||
sorted_indices = np.argsort(angles)
|
||||
sorted_points = hull_points[sorted_indices]
|
||||
|
||||
# Split into 4 groups (4 sides)
|
||||
n = len(sorted_points)
|
||||
quarter = n // 4
|
||||
|
||||
side1 = sorted_points[0:quarter]
|
||||
side2 = sorted_points[quarter:2*quarter]
|
||||
side3 = sorted_points[2*quarter:3*quarter]
|
||||
side4 = sorted_points[3*quarter:]
|
||||
|
||||
# Fit lines for each side using cv2.fitLine
|
||||
def fit_line_coefficients(pts):
|
||||
if len(pts) < 2:
|
||||
return None
|
||||
# cv2.fitLine returns (vx, vy, x0, y0) - direction vector and point on line
|
||||
line = cv2.fitLine(pts, cv2.DIST_L2, 0, 0.01, 0.01)
|
||||
vx, vy, x0, y0 = line[0][0], line[1][0], line[2][0], line[3][0]
|
||||
# Convert to line equation: y = mx + b or vertical line x = c
|
||||
if abs(vx) < 1e-6: # Vertical line
|
||||
return ('vertical', x0)
|
||||
m = vy / vx
|
||||
b = y0 - m * x0
|
||||
return ('normal', m, b)
|
||||
|
||||
line1 = fit_line_coefficients(side1)
|
||||
line2 = fit_line_coefficients(side2)
|
||||
line3 = fit_line_coefficients(side3)
|
||||
line4 = fit_line_coefficients(side4)
|
||||
|
||||
lines = [line1, line2, line3, line4]
|
||||
|
||||
# Find intersections between adjacent sides
|
||||
def line_intersection(line_a, line_b):
|
||||
if line_a is None or line_b is None:
|
||||
return None
|
||||
|
||||
# Handle vertical lines
|
||||
if line_a[0] == 'vertical' and line_b[0] == 'vertical':
|
||||
return None
|
||||
elif line_a[0] == 'vertical':
|
||||
x = line_a[1]
|
||||
m2, b2 = line_b[1], line_b[2]
|
||||
y = m2 * x + b2
|
||||
return [float(x), float(y)]
|
||||
elif line_b[0] == 'vertical':
|
||||
x = line_b[1]
|
||||
m1, b1 = line_a[1], line_a[2]
|
||||
y = m1 * x + b1
|
||||
return [float(x), float(y)]
|
||||
else:
|
||||
m1, b1 = line_a[1], line_a[2]
|
||||
m2, b2 = line_b[1], line_b[2]
|
||||
|
||||
if abs(m1 - m2) < 1e-6: # Parallel lines
|
||||
return None
|
||||
|
||||
x = (b2 - b1) / (m1 - m2)
|
||||
y = m1 * x + b1
|
||||
return [float(x), float(y)]
|
||||
|
||||
# Find 4 corners as intersections
|
||||
corners = []
|
||||
for i in range(4):
|
||||
next_i = (i + 1) % 4
|
||||
corner = line_intersection(lines[i], lines[next_i])
|
||||
if corner:
|
||||
corners.append(corner)
|
||||
|
||||
# If we got 4 corners, return them
|
||||
if len(corners) == 4:
|
||||
return corners
|
||||
|
||||
# Fallback: use convex hull extreme points
|
||||
tl = hull_points[np.argmin(hull_points[:, 0] + hull_points[:, 1])].tolist()
|
||||
tr = hull_points[np.argmax(hull_points[:, 0] - hull_points[:, 1])].tolist()
|
||||
br = hull_points[np.argmax(hull_points[:, 0] + hull_points[:, 1])].tolist()
|
||||
bl = hull_points[np.argmin(hull_points[:, 0] - hull_points[:, 1])].tolist()
|
||||
|
||||
return [tl, tr, br, bl]
|
||||
|
||||
|
||||
def _extract_court_corners(keypoints: List[Dict], width: int, height: int) -> List[List[float]]:
|
||||
"""
|
||||
Extract 4 court corners from detected keypoints (old function for compatibility)
|
||||
"""
|
||||
if len(keypoints) < 4:
|
||||
return _estimate_court_corners(width, height)
|
||||
|
||||
points = [[kp['x'], kp['y']] for kp in keypoints]
|
||||
return _extract_court_corners_from_points(points, width, height)
|
||||
Reference in New Issue
Block a user