278 lines
9.7 KiB
Python
278 lines
9.7 KiB
Python
"""Asset 2: Detect court keypoints using Roboflow Hosted API"""
|
||
|
||
import os
|
||
import cv2
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from typing import Dict, List
|
||
from dagster import asset, AssetExecutionContext
|
||
from inference_sdk import InferenceHTTPClient
|
||
|
||
|
||
@asset(
    io_manager_key="json_io_manager",
    compute_kind="roboflow",
    description="Detect pickleball court corners using Roboflow keypoint detection model"
)
def detect_court_keypoints(
    context: AssetExecutionContext,
    extract_video_frames: Dict
) -> Dict:
    """
    Detect court keypoints from first frame using Roboflow model

    Inputs:
        - extract_video_frames: metadata from frame extraction
        - data/frames/frame_0000.jpg: first frame

    Outputs:
        - data/detect_court_keypoints.json
        - data/{run_id}/court_detection_preview.jpg (visualization)

    Returns:
        Dict with:
        - corners_pixel: list of 4 corner coordinates [[x,y], ...]
        - court_width_m: court width in meters (6.1)
        - court_length_m: court length in meters (13.4)
        - frame_width / frame_height: source frame dimensions in pixels

    Raises:
        FileNotFoundError: if the first extracted frame is missing.
        ValueError: if the frame file exists but cannot be decoded.
    """
    frames_dir = Path(extract_video_frames['frames_dir'])
    first_frame_path = frames_dir / "frame_0000.jpg"

    context.log.info(f"Loading first frame: {first_frame_path}")

    if not first_frame_path.exists():
        raise FileNotFoundError(f"First frame not found: {first_frame_path}")

    # Load frame. cv2.imread returns None (no exception) on unreadable or
    # corrupt images, so check explicitly before touching .shape.
    frame = cv2.imread(str(first_frame_path))
    if frame is None:
        raise ValueError(f"Failed to decode frame: {first_frame_path}")
    h, w = frame.shape[:2]
    context.log.info(f"Frame dimensions: {w}x{h}")

    # BUGFIX: all_points was previously assigned only inside the API branch,
    # so a missing API key (or an inference failure before the assignment)
    # caused a NameError at visualization time. Initialize it up front.
    all_points: List[List[float]] = []

    # Get API key
    api_key = os.getenv("ROBOFLOW_API_KEY")
    if not api_key:
        context.log.warning("ROBOFLOW_API_KEY not set, using estimated corners")
        corners = _estimate_court_corners(w, h)
    else:
        # Try to detect court using Roboflow Hosted API
        try:
            context.log.info("Detecting court using Roboflow Hosted API...")

            client = InferenceHTTPClient(
                api_url="https://serverless.roboflow.com",
                api_key=api_key
            )

            result = client.infer(str(first_frame_path), model_id="pickleball-court-cfyv4/1")

            # Extract keypoints from result
            if result and 'predictions' in result and len(result['predictions']) > 0:
                pred = result['predictions'][0]
                if 'points' in pred and len(pred['points']) >= 4:
                    # The model returns many points (court lines)
                    all_points = [[p['x'], p['y']] for p in pred['points']]
                    context.log.info(f"✓ Detected {len(all_points)} keypoints from court lines")

                    # Derive the 4 corners for calibration (not for visualization)
                    corners = _extract_court_corners_from_points(all_points, w, h)
                    context.log.info(f"✓ Extracted 4 corners from keypoints")
                else:
                    context.log.warning("No keypoints in prediction, using estimated corners")
                    corners = _estimate_court_corners(w, h)
            else:
                context.log.warning("No predictions from model, using estimated corners")
                corners = _estimate_court_corners(w, h)

        except Exception as e:
            # Deliberate best-effort: any API/network failure degrades to the
            # heuristic corner estimate rather than failing the asset.
            context.log.warning(f"Court detection failed: {e}. Using estimated corners.")
            corners = _estimate_court_corners(w, h)

    context.log.info(f"Court corners: {corners}")

    # Save visualization - draw ALL points and lines returned by the model
    vis_frame = frame.copy()

    if len(all_points) > 0:
        context.log.info(f"Drawing {len(all_points)} keypoints on visualization")

        # Draw every detected point, labelled with its index
        for i, point in enumerate(all_points):
            x, y = int(point[0]), int(point[1])
            cv2.circle(vis_frame, (x, y), 5, (0, 255, 0), -1)
            cv2.putText(
                vis_frame,
                str(i),
                (x + 8, y),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.4,
                (255, 255, 0),
                1
            )

        # Connect consecutive points with lines
        for i in range(len(all_points) - 1):
            p1 = tuple(map(int, all_points[i]))
            p2 = tuple(map(int, all_points[i + 1]))
            cv2.line(vis_frame, p1, p2, (0, 255, 0), 2)

    # Save visualization with run_id
    run_id = context.run_id
    vis_path = Path(f"data/{run_id}/court_detection_preview.jpg")
    # BUGFIX: cv2.imwrite fails silently when the target directory is missing
    vis_path.parent.mkdir(parents=True, exist_ok=True)
    cv2.imwrite(str(vis_path), vis_frame)
    context.log.info(f"Saved court visualization to {vis_path}")

    return {
        "corners_pixel": corners,
        "court_width_m": 6.1,
        "court_length_m": 13.4,
        "frame_width": w,
        "frame_height": h
    }
|
||
|
||
|
||
def _estimate_court_corners(width: int, height: int) -> List[List[float]]:
|
||
"""
|
||
Estimate court corners based on typical DJI camera position
|
||
(camera in corner at angle)
|
||
|
||
Returns corners in order: [TL, TR, BR, BL]
|
||
"""
|
||
# Assume court takes up ~80% of frame with perspective
|
||
margin_x = width * 0.05
|
||
margin_y = height * 0.1
|
||
|
||
# Perspective: far edge narrower than near edge
|
||
return [
|
||
[margin_x + width * 0.1, margin_y], # Top-left (far)
|
||
[width - margin_x - width * 0.1, margin_y], # Top-right (far)
|
||
[width - margin_x, height - margin_y], # Bottom-right (near)
|
||
[margin_x, height - margin_y] # Bottom-left (near)
|
||
]
|
||
|
||
|
||
def _extract_court_corners_from_points(points: List[List[float]], width: int, height: int) -> List[List[float]]:
    """
    Extract 4 court corners from many detected points (court lines)

    Strategy:
    1. Build convex hull from all points
    2. Classify hull points into 4 sides (left, right, top, bottom)
    3. Fit line for each side using linear regression
    4. Find 4 corners as intersections of fitted lines

    This works even if one corner is not visible on frame (extrapolation)

    Args:
        points: detected keypoints as [x, y] pixel coordinates
        width: frame width in pixels (used only for the fallback estimate)
        height: frame height in pixels (used only for the fallback estimate)

    Returns:
        List of 4 [x, y] corners. Order follows the angular ordering of the
        hull sides — NOTE(review): not guaranteed to be [TL, TR, BR, BL];
        confirm against downstream calibration expectations.
    """
    # Too few points to form a quadrilateral — fall back to heuristic corners.
    if len(points) < 4:
        return _estimate_court_corners(width, height)

    # Build convex hull from all points
    points_array = np.array(points, dtype=np.float32)
    hull = cv2.convexHull(points_array)
    # cv2.convexHull returns shape (k, 1, 2); flatten to a (k, 2) array.
    hull_points = np.array([p[0] for p in hull], dtype=np.float32)

    # Classify hull points into 4 sides
    # Strategy: sort hull points by angle from centroid, then split into 4 groups
    center = hull_points.mean(axis=0)

    # Calculate angle for each point relative to center
    angles = np.arctan2(hull_points[:, 1] - center[1], hull_points[:, 0] - center[0])

    # Sort points by angle
    sorted_indices = np.argsort(angles)
    sorted_points = hull_points[sorted_indices]

    # Split into 4 groups (4 sides)
    # NOTE(review): equal-size angular quarters only approximate the true
    # sides; a side contributing few hull points may be under-represented.
    # If the hull has < 4 points, quarter == 0 and the fitters below return
    # None, which routes execution to the extreme-point fallback.
    n = len(sorted_points)
    quarter = n // 4

    side1 = sorted_points[0:quarter]
    side2 = sorted_points[quarter:2*quarter]
    side3 = sorted_points[2*quarter:3*quarter]
    side4 = sorted_points[3*quarter:]

    # Fit lines for each side using cv2.fitLine
    def fit_line_coefficients(pts):
        # Returns ('vertical', x0), ('normal', slope, intercept), or None
        # when there are not enough points to fit a line.
        if len(pts) < 2:
            return None
        # cv2.fitLine returns (vx, vy, x0, y0) - direction vector and point on line
        line = cv2.fitLine(pts, cv2.DIST_L2, 0, 0.01, 0.01)
        vx, vy, x0, y0 = line[0][0], line[1][0], line[2][0], line[3][0]
        # Convert to line equation: y = mx + b or vertical line x = c
        if abs(vx) < 1e-6:  # Vertical line
            return ('vertical', x0)
        m = vy / vx
        b = y0 - m * x0
        return ('normal', m, b)

    line1 = fit_line_coefficients(side1)
    line2 = fit_line_coefficients(side2)
    line3 = fit_line_coefficients(side3)
    line4 = fit_line_coefficients(side4)

    lines = [line1, line2, line3, line4]

    # Find intersections between adjacent sides
    def line_intersection(line_a, line_b):
        # Returns [x, y] intersection, or None when either line is missing,
        # both are vertical, or the slopes are (near-)parallel.
        if line_a is None or line_b is None:
            return None

        # Handle vertical lines
        if line_a[0] == 'vertical' and line_b[0] == 'vertical':
            return None
        elif line_a[0] == 'vertical':
            x = line_a[1]
            m2, b2 = line_b[1], line_b[2]
            y = m2 * x + b2
            return [float(x), float(y)]
        elif line_b[0] == 'vertical':
            x = line_b[1]
            m1, b1 = line_a[1], line_a[2]
            y = m1 * x + b1
            return [float(x), float(y)]
        else:
            m1, b1 = line_a[1], line_a[2]
            m2, b2 = line_b[1], line_b[2]

            if abs(m1 - m2) < 1e-6:  # Parallel lines
                return None

            x = (b2 - b1) / (m1 - m2)
            y = m1 * x + b1
            return [float(x), float(y)]

    # Find 4 corners as intersections of consecutive (angularly adjacent) sides
    corners = []
    for i in range(4):
        next_i = (i + 1) % 4
        corner = line_intersection(lines[i], lines[next_i])
        if corner:
            corners.append(corner)

    # If we got 4 corners, return them
    if len(corners) == 4:
        return corners

    # Fallback: use convex hull extreme points.
    # argmin/argmax of x+y and x-y select the classic four extreme corners
    # (min x+y ~ top-left, max x+y ~ bottom-right, etc. in image coordinates).
    tl = hull_points[np.argmin(hull_points[:, 0] + hull_points[:, 1])].tolist()
    tr = hull_points[np.argmax(hull_points[:, 0] - hull_points[:, 1])].tolist()
    br = hull_points[np.argmax(hull_points[:, 0] + hull_points[:, 1])].tolist()
    bl = hull_points[np.argmin(hull_points[:, 0] - hull_points[:, 1])].tolist()

    return [tl, tr, br, bl]
|
||
|
||
|
||
def _extract_court_corners(keypoints: List[Dict], width: int, height: int) -> List[List[float]]:
    """
    Extract 4 court corners from detected keypoints (old function for compatibility)

    Args:
        keypoints: dicts with at least 'x' and 'y' pixel coordinates
        width: frame width in pixels
        height: frame height in pixels

    Returns:
        List of 4 [x, y] corner coordinates.
    """
    # Too few keypoints to work with — fall back to the heuristic estimate.
    if len(keypoints) < 4:
        return _estimate_court_corners(width, height)

    # Unpack dicts into [x, y] pairs and delegate to the point-based routine.
    xy_pairs = [[kp['x'], kp['y']] for kp in keypoints]
    return _extract_court_corners_from_points(xy_pairs, width, height)
|