rave/src/core/group.rs

449 lines
16 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

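//! Stream manager — implementation of the `StreamManagerApi` trait.
//!
//! `StreamManager` is the registry of all active streams:
//! - a `HashMap<String, Arc<Stream>>` maps each stream path to its stream instance;
//! - it implements the `StreamManagerApi` trait so plugins can reach it through `EngineContext`;
//! - a read/write lock (`RwLock`) suits the read-heavy, write-light access pattern.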
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::core::stream::Stream;
use crate::sdk::traits::{PublisherApi, StreamManagerApi, SubscriberApi};
use crate::sdk::types::{AVFrame, StreamCodecMeta, StreamPath, StreamSummary};
/// Stream manager.
///
/// Registry of every active stream instance, keyed by the stream's full path
/// (the string produced by `StreamPath::full_path()`).
#[derive(Default)]
pub struct StreamManager {
    /// Full stream path → stream instance.
    ///
    /// Guarded by an `RwLock` because lookups (dispatch, subscribe, queries)
    /// are far more frequent than registrations and removals.
    streams: RwLock<HashMap<String, Arc<Stream>>>,
}
impl StreamManager {
    /// Creates an empty stream manager with no registered streams.
    ///
    /// Equivalent to `StreamManager::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl StreamManagerApi for StreamManager {
    /// Creates and registers a new stream, returning its publisher handle.
    ///
    /// The existence check and the insertion are performed under a single
    /// write lock. Two concurrent callers creating the same path therefore
    /// cannot both succeed, and a later call can never overwrite (and
    /// orphan) a stream that is already registered.
    ///
    /// # Errors
    ///
    /// Returns `Err("stream '<path>' already exists")` when a stream with the
    /// same full path is already registered. The existing stream is left
    /// untouched.
    fn create_stream(&self, path: StreamPath) -> Result<Arc<dyn PublisherApi>, String> {
        let key = path.full_path();
        let mut streams = self.streams.write().unwrap();
        if streams.contains_key(&key) {
            return Err(format!("stream '{}' already exists", key));
        }
        let stream = Arc::new(Stream::new(path));
        let publisher = stream.publisher();
        streams.insert(key, stream);
        Ok(publisher)
    }

    /// Returns the publisher handle of an existing stream, or `None` if no
    /// stream is registered under `path`.
    fn get_stream(&self, path: &StreamPath) -> Option<Arc<dyn PublisherApi>> {
        let key = path.full_path();
        let streams = self.streams.read().unwrap();
        streams.get(&key).map(|s| s.publisher() as Arc<dyn PublisherApi>)
    }

    /// Dispatches one frame to the stream registered under `path`.
    ///
    /// Delegates to `Stream::dispatch_frame`, which writes the frame into the
    /// GOP cache and broadcasts it to all subscribers. If the stream does not
    /// exist, the frame is silently dropped (the publisher may already have
    /// disconnected).
    fn dispatch_frame(&self, path: &StreamPath, frame: AVFrame) {
        let key = path.full_path();
        // Clone the `Arc` and release the registry lock before dispatching, so
        // fan-out to subscribers does not block stream creation/removal.
        let stream = self.streams.read().unwrap().get(&key).cloned();
        if let Some(stream) = stream {
            stream.dispatch_frame(frame);
        }
    }

    /// Subscribes to the stream registered under `path`.
    ///
    /// A new subscriber first receives the frames currently held in the
    /// stream's GOP cache, starting from the most recent keyframe.
    ///
    /// # Errors
    ///
    /// Returns `Err("stream '<path>' not found")` when no stream is registered
    /// under `path`.
    fn subscribe(&self, path: &StreamPath) -> Result<Arc<dyn SubscriberApi>, String> {
        let key = path.full_path();
        // Take a cheap `Arc` snapshot so the subscription setup (which copies
        // the GOP cache) runs without holding the registry lock.
        let stream = self.streams.read().unwrap().get(&key).cloned();
        match stream {
            Some(stream) => Ok(stream.subscribe() as Arc<dyn SubscriberApi>),
            None => Err(format!("stream '{}' not found", key)),
        }
    }

    /// Unregisters the stream under `path`; does nothing if it is absent.
    fn remove_stream(&self, path: &StreamPath) {
        let key = path.full_path();
        let mut streams = self.streams.write().unwrap();
        streams.remove(&key);
    }

    /// Returns the number of currently registered streams.
    fn stream_count(&self) -> usize {
        self.streams.read().unwrap().len()
    }

    /// Returns `true` if a stream is registered under `path`.
    fn has_stream(&self, path: &StreamPath) -> bool {
        let key = path.full_path();
        self.streams.read().unwrap().contains_key(&key)
    }

    /// Returns the paths of all registered streams (in unspecified order).
    fn list_streams(&self) -> Vec<StreamPath> {
        let streams = self.streams.read().unwrap();
        streams.values().map(|s| s.path().clone()).collect()
    }

    /// Returns a summary of every registered stream (in unspecified order).
    fn stream_summaries(&self) -> Vec<StreamSummary> {
        let streams = self.streams.read().unwrap();
        streams.values().map(|s| s.summary()).collect()
    }

    /// Returns the codec metadata of the stream under `path`, or `None` if no
    /// such stream is registered.
    fn get_codec_metadata(&self, path: &StreamPath) -> Option<StreamCodecMeta> {
        let key = path.full_path();
        let streams = self.streams.read().unwrap();
        streams.get(&key).map(|s| s.codec_metadata())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sdk::types::{FrameType, StreamPath, VideoCodec, AudioCodec, AVFrame, CodecExtraInfo, H264SeqHeader, AacSeqHeader};
    use std::sync::Arc;

    #[test]
    fn test_stream_manager_new_is_empty() {
        let mgr = StreamManager::new();
        assert_eq!(mgr.stream_count(), 0);
        assert!(mgr.list_streams().is_empty());
    }

    #[test]
    fn test_stream_manager_create_stream_succeeds() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "test");
        let result = mgr.create_stream(path.clone());
        assert!(result.is_ok());
        assert_eq!(mgr.stream_count(), 1);
        assert!(mgr.has_stream(&path));
    }

    #[test]
    fn test_stream_manager_create_duplicate_stream_fails() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "dup");
        mgr.create_stream(path.clone()).unwrap();
        let result = mgr.create_stream(path);
        assert!(result.is_err());
        let err_msg = result.err().unwrap();
        assert!(err_msg.contains("already exists"));
    }

    /// A rejected duplicate `create_stream` must leave the original stream
    /// registered and functional for its existing subscribers.
    #[test]
    fn test_stream_manager_create_duplicate_keeps_original_stream() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "keep_original");
        let original = create_stream_and_get_inner(&mgr, &path);
        let subscriber = mgr.subscribe(&path).unwrap();
        assert!(mgr.create_stream(path.clone()).is_err());
        assert_eq!(mgr.stream_count(), 1);
        let registered = mgr.streams.read().unwrap().get(&path.full_path()).unwrap().clone();
        assert!(Arc::ptr_eq(&original, &registered));
        original.dispatch_frame(make_keyframe(7));
        assert_eq!(subscriber.read_frame().unwrap().timestamp_ms, 7);
    }

    #[test]
    fn test_stream_manager_get_stream_succeeds() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "get");
        mgr.create_stream(path.clone()).unwrap();
        let pub_ = mgr.get_stream(&path);
        assert!(pub_.is_some());
        assert_eq!(pub_.unwrap().stream_path().full_path(), "live/get");
    }

    #[test]
    fn test_stream_manager_get_stream_missing_returns_none() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "missing");
        assert!(mgr.get_stream(&path).is_none());
    }

    #[test]
    fn test_stream_manager_subscribe_succeeds() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "sub");
        mgr.create_stream(path.clone()).unwrap();
        let result = mgr.subscribe(&path);
        assert!(result.is_ok());
    }

    #[test]
    fn test_stream_manager_subscribe_missing_fails() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "nosub");
        let result = mgr.subscribe(&path);
        assert!(result.is_err());
        let err_msg = result.err().unwrap();
        assert!(err_msg.contains("not found"));
    }

    #[test]
    fn test_stream_manager_remove_stream_succeeds() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "rm");
        mgr.create_stream(path.clone()).unwrap();
        assert!(mgr.has_stream(&path));
        mgr.remove_stream(&path);
        assert!(!mgr.has_stream(&path));
        assert_eq!(mgr.stream_count(), 0);
    }

    /// Removing a path that was never registered is a no-op.
    #[test]
    fn test_stream_manager_remove_missing_stream_is_noop() {
        let mgr = StreamManager::new();
        mgr.create_stream(StreamPath::new("live", "stay")).unwrap();
        mgr.remove_stream(&StreamPath::new("live", "absent"));
        assert_eq!(mgr.stream_count(), 1);
    }

    #[test]
    fn test_stream_manager_list_streams_returns_all() {
        let mgr = StreamManager::new();
        mgr.create_stream(StreamPath::new("live", "a")).unwrap();
        mgr.create_stream(StreamPath::new("live", "b")).unwrap();
        let list = mgr.list_streams();
        assert_eq!(list.len(), 2);
        // Order is unspecified (HashMap-backed), so compare sorted full paths.
        let mut names: Vec<String> = list.iter().map(|p| p.full_path()).collect();
        names.sort();
        assert_eq!(names, vec!["live/a".to_string(), "live/b".to_string()]);
    }

    /// Frames dispatched through the manager (not the inner stream) reach
    /// subscribers in order.
    #[test]
    fn test_stream_manager_dispatch_frame_reaches_subscriber() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "mgr_dispatch");
        mgr.create_stream(path.clone()).unwrap();
        let subscriber = mgr.subscribe(&path).unwrap();
        mgr.dispatch_frame(&path, make_keyframe(10));
        mgr.dispatch_frame(&path, make_interframe(43));
        assert_eq!(subscriber.read_frame().unwrap().timestamp_ms, 10);
        assert_eq!(subscriber.read_frame().unwrap().timestamp_ms, 43);
        assert!(subscriber.read_frame().is_none());
    }

    /// Dispatching to an unregistered path is silently ignored and does not
    /// create a stream.
    #[test]
    fn test_stream_manager_dispatch_frame_missing_stream_is_ignored() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "ghost");
        mgr.dispatch_frame(&path, make_keyframe(0));
        assert!(!mgr.has_stream(&path));
        assert_eq!(mgr.stream_count(), 0);
    }

    // ========== Publish/subscribe integration tests ==========

    /// Helper: builds an H.264 keyframe with the given timestamp.
    fn make_keyframe(ts: u64) -> AVFrame {
        AVFrame::new_video(ts, Arc::new(vec![0x65]), VideoCodec::H264, FrameType::KeyFrame)
    }

    /// Helper: builds an H.264 inter (non-key) frame with the given timestamp.
    fn make_interframe(ts: u64) -> AVFrame {
        AVFrame::new_video(ts, Arc::new(vec![0x61]), VideoCodec::H264, FrameType::InterFrame)
    }

    /// Helper: builds an AAC audio frame with the given timestamp.
    fn make_audio(ts: u64) -> AVFrame {
        AVFrame::new_audio(ts, Arc::new(vec![0xAF]), AudioCodec::Aac)
    }

    /// Helper: creates a stream through the manager and returns the inner
    /// `Stream` registered for it.
    fn create_stream_and_get_inner(mgr: &StreamManager, path: &StreamPath) -> Arc<Stream> {
        mgr.create_stream(path.clone()).unwrap();
        mgr.streams.read().unwrap().get(&path.full_path()).unwrap().clone()
    }

    /// Integration: full publish → subscribe data path.
    ///
    /// Frames are written via `Stream::dispatch_frame` and must be delivered
    /// to the subscriber in order.
    #[test]
    fn test_pub_sub_full_pipeline_succeeds() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "pipeline");
        let stream = create_stream_and_get_inner(&mgr, &path);
        let subscriber = mgr.subscribe(&path).unwrap();
        // Publish: writes the GOP cache and fans out to subscribers.
        stream.dispatch_frame(make_keyframe(0));
        stream.dispatch_frame(make_audio(0));
        stream.dispatch_frame(make_interframe(33));
        stream.dispatch_frame(make_audio(33));
        // The subscriber receives every frame, in order.
        let f1 = subscriber.read_frame().unwrap();
        assert!(f1.is_keyframe());
        assert_eq!(f1.timestamp_ms, 0);
        let f2 = subscriber.read_frame().unwrap();
        assert!(f2.is_audio());
        assert_eq!(f2.timestamp_ms, 0);
        let f3 = subscriber.read_frame().unwrap();
        assert!(!f3.is_keyframe());
        assert_eq!(f3.timestamp_ms, 33);
        let f4 = subscriber.read_frame().unwrap();
        assert!(f4.is_audio());
        assert_eq!(f4.timestamp_ms, 33);
        // Nothing left once all frames are read.
        assert!(subscriber.read_frame().is_none());
    }

    /// Integration: GOP cache — a late subscriber receives the cached GOP
    /// starting at the most recent keyframe.
    #[test]
    fn test_pub_sub_gop_cache_new_subscriber_gets_cached_keyframe() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "gop");
        let stream = create_stream_and_get_inner(&mgr, &path);
        // First complete GOP.
        stream.dispatch_frame(make_keyframe(0));
        stream.dispatch_frame(make_interframe(33));
        stream.dispatch_frame(make_interframe(66));
        // A new keyframe starts a new GOP and clears the old cache.
        stream.dispatch_frame(make_keyframe(100));
        stream.dispatch_frame(make_interframe(133));
        // A late subscriber must receive only the second GOP.
        let late_sub = mgr.subscribe(&path).unwrap();
        let f1 = late_sub.read_frame().unwrap();
        assert!(f1.is_keyframe());
        assert_eq!(f1.timestamp_ms, 100);
        let f2 = late_sub.read_frame().unwrap();
        assert_eq!(f2.timestamp_ms, 133);
        assert!(late_sub.read_frame().is_none());
    }

    /// Integration: multiple subscribers on one stream all receive frames.
    #[test]
    fn test_pub_sub_multiple_subscribers_all_receive_frames() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "multi");
        let stream = create_stream_and_get_inner(&mgr, &path);
        let sub1 = mgr.subscribe(&path).unwrap();
        let sub2 = mgr.subscribe(&path).unwrap();
        let sub3 = mgr.subscribe(&path).unwrap();
        // One frame published; every subscriber must see it.
        stream.dispatch_frame(make_keyframe(42));
        assert_eq!(sub1.read_frame().unwrap().timestamp_ms, 42);
        assert_eq!(sub2.read_frame().unwrap().timestamp_ms, 42);
        assert_eq!(sub3.read_frame().unwrap().timestamp_ms, 42);
    }

    /// Integration: closing the publisher does not affect already-queued frames.
    #[test]
    fn test_pub_sub_close_publisher_frames_already_queued() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "close");
        let stream = create_stream_and_get_inner(&mgr, &path);
        let publisher = mgr.get_stream(&path).unwrap();
        let subscriber = mgr.subscribe(&path).unwrap();
        // Queue frames for the subscriber first.
        stream.dispatch_frame(make_keyframe(0));
        stream.dispatch_frame(make_audio(0));
        // Close the publisher.
        publisher.close();
        assert!(!publisher.is_active());
        // Writes after close must fail.
        assert!(publisher.write_frame(make_keyframe(1)).is_err());
        // Frames queued before close are still readable.
        assert_eq!(subscriber.read_frame().unwrap().timestamp_ms, 0);
        assert_eq!(subscriber.read_frame().unwrap().timestamp_ms, 0);
        assert!(subscriber.read_frame().is_none());
    }

    /// Integration: streams are isolated from each other.
    #[test]
    fn test_pub_sub_different_streams_isolated() {
        let mgr = StreamManager::new();
        let path_a = StreamPath::new("live", "streamA");
        let path_b = StreamPath::new("live", "streamB");
        let stream_a = create_stream_and_get_inner(&mgr, &path_a);
        let stream_b = create_stream_and_get_inner(&mgr, &path_b);
        let sub_a = mgr.subscribe(&path_a).unwrap();
        let sub_b = mgr.subscribe(&path_b).unwrap();
        stream_a.dispatch_frame(make_keyframe(100));
        stream_b.dispatch_frame(make_keyframe(200));
        // Each subscriber sees only its own stream's frames.
        assert_eq!(sub_a.read_frame().unwrap().timestamp_ms, 100);
        assert!(sub_a.read_frame().is_none());
        assert_eq!(sub_b.read_frame().unwrap().timestamp_ms, 200);
        assert!(sub_b.read_frame().is_none());
    }

    /// Integration: backpressure — a slow subscriber drops frames instead of
    /// blocking the publisher.
    #[test]
    fn test_pub_sub_backpressure_slow_subscriber_drops_frames() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "backpressure");
        let stream = create_stream_and_get_inner(&mgr, &path);
        // Subscribe first, then push more frames than the queue capacity (1024).
        let slow_sub = mgr.subscribe(&path).unwrap();
        for i in 0..1100u64 {
            stream.dispatch_frame(make_keyframe(i));
        }
        // The oldest frames were dropped (capacity 1024, 1100 pushed), so at
        // least 1100 - 1024 = 76 frames are gone and the first readable
        // timestamp must be >= 76.
        let first = slow_sub.read_frame().unwrap();
        assert!(first.timestamp_ms >= 76,
            "expected first frame ts >= 76 (dropped 76 frames), got {}", first.timestamp_ms);
        // The most recent frame must still be readable.
        let mut last_ts = first.timestamp_ms;
        while let Some(f) = slow_sub.read_frame() {
            last_ts = f.timestamp_ms;
        }
        assert_eq!(last_ts, 1099, "last frame should be the most recent");
    }

    /// `get_codec_metadata` returns `None` for an unregistered stream.
    #[test]
    fn test_stream_manager_get_codec_metadata_missing_returns_none() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "missing");
        assert!(mgr.get_codec_metadata(&path).is_none());
    }

    /// `get_codec_metadata` extracts codec metadata from sequence-header frames.
    #[test]
    fn test_stream_manager_get_codec_metadata_extracts_meta() {
        let mgr = StreamManager::new();
        let path = StreamPath::new("live", "meta");
        let stream = create_stream_and_get_inner(&mgr, &path);
        // Video sequence header carrying SPS/PPS.
        let sps = Arc::new(vec![0x67, 0x64, 0x00, 0x29, 0xAC]);
        let pps = Arc::new(vec![0x68, 0xEE, 0x31, 0x12]);
        let mut vframe = AVFrame::new_video(0, Arc::new(vec![]), VideoCodec::H264, FrameType::KeyFrame);
        vframe.codec_info = Some(CodecExtraInfo::H264SeqHeader(H264SeqHeader {
            sps: sps.clone(),
            pps: pps.clone(),
        }));
        stream.dispatch_frame(vframe);
        // Audio sequence header carrying the AAC AudioSpecificConfig.
        let aac_config = Arc::new(vec![0x12, 0x10]); // AAC-LC, 44100 Hz, 2 channels
        let mut aframe = AVFrame::new_audio(0, Arc::new(vec![]), AudioCodec::Aac);
        aframe.codec_info = Some(CodecExtraInfo::AacSeqHeader(AacSeqHeader {
            audio_specific_config: aac_config.clone(),
        }));
        stream.dispatch_frame(aframe);
        let meta = mgr.get_codec_metadata(&path).expect("应返回元数据");
        assert!(meta.h264_sps.is_some(), "应有 SPS");
        assert!(meta.h264_pps.is_some(), "应有 PPS");
        assert!(meta.aac_config.is_some(), "应有 AAC config");
        assert_eq!(meta.audio_sample_rate, 44100, "采样率应为 44100");
        assert_eq!(meta.audio_channels, 2, "通道数应为 2");
    }
}