Files

280 lines
9.4 KiB
Rust
Raw Permalink Normal View History

use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use std::sync::atomic::{AtomicU64, Ordering};
use anyhow::{Context, Result};
use dashmap::DashMap;
use quinn::{Endpoint, ServerConfig, TransportConfig};
use quinn::crypto::rustls::QuicServerConfig;
use rcgen::{CertifiedKey, generate_simple_self_signed};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio::net::UdpSocket;
use tracing::{error, info, warn};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use uuid::Uuid;
mod auth;
mod config;
use auth::validate_relay_token;
use config::Config;
/// Members of one room: user id -> that user's QUIC connection.
type RoomPeers = DashMap<Uuid, quinn::Connection>;
/// All rooms currently known to the relay: room id (from the auth handshake) -> its members.
type Rooms = DashMap<Uuid, RoomPeers>;
/// Count of incoming connections accepted by the endpoint since startup (counted before the
/// QUIC handshake and authentication); reported by the UDP ping responder.
static TOTAL_CONNECTIONS: AtomicU64 = AtomicU64::new(0);
/// Count of connection-handler tasks currently running; reported by the UDP ping responder.
static ACTIVE_CONNECTIONS: AtomicU64 = AtomicU64::new(0);
/// Keeps [`ACTIVE_CONNECTIONS`] in sync with the lifetime of a connection-handler task.
///
/// The counter is incremented on [`ActiveConnectionGuard::acquire`] and decremented on drop.
/// Tying the decrement to `Drop` keeps the gauge accurate even if the task panics or is
/// dropped at runtime shutdown, which an explicit `fetch_sub` after the handler would miss.
struct ActiveConnectionGuard;

impl ActiveConnectionGuard {
    /// Increments the active-connection gauge and returns the guard that undoes it.
    fn acquire() -> Self {
        ACTIVE_CONNECTIONS.fetch_add(1, Ordering::Relaxed);
        ActiveConnectionGuard
    }
}

impl Drop for ActiveConnectionGuard {
    fn drop(&mut self) {
        ACTIVE_CONNECTIONS.fetch_sub(1, Ordering::Relaxed);
    }
}

/// Relay server entry point.
///
/// Loads `.env` (if present) and tracing config, starts the UDP ping responder, binds the QUIC
/// endpoint on `config.listen_addr`, and spawns one handler task per incoming connection until
/// the endpoint closes.
///
/// # Errors
/// Fails if the QUIC server config cannot be built or the endpoint cannot bind.
#[tokio::main]
async fn main() -> Result<()> {
    dotenvy::dotenv().ok();
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new(
            std::env::var("RUST_LOG")
                .unwrap_or_else(|_| "funmc_relay_server=info,quinn=warn".into()),
        ))
        .with(tracing_subscriber::fmt::layer())
        .init();

    let config = Config::from_env();
    info!("╔══════════════════════════════════════════════════════════╗");
    info!("║ FunMC 中继服务端 v{} ║", env!("CARGO_PKG_VERSION"));
    info!("║ 魔幻方开发 ║");
    info!("╠══════════════════════════════════════════════════════════╣");
    info!("║ 监听地址: {:43} ║", config.listen_addr);
    info!("╚══════════════════════════════════════════════════════════╝");

    let rooms: Arc<Rooms> = Arc::new(DashMap::new());

    // UDP ping responder for latency measurements. It binds listen port + 10000 (see
    // `run_ping_responder`); the QUIC endpoint below binds the listen port itself.
    let ping_addr = config.listen_addr;
    tokio::spawn(async move {
        if let Err(e) = run_ping_responder(ping_addr).await {
            warn!("Ping responder error: {}", e);
        }
    });

    let server_config = build_server_config()?;
    let endpoint = Endpoint::server(server_config, config.listen_addr)
        .context("无法绑定 QUIC 端口")?;
    info!("QUIC 中继服务已启动,等待连接...");

    while let Some(incoming) = endpoint.accept().await {
        TOTAL_CONNECTIONS.fetch_add(1, Ordering::Relaxed);
        // Acquired before spawning so the gauge rises immediately, as before.
        let active = ActiveConnectionGuard::acquire();
        let rooms = Arc::clone(&rooms);
        let jwt_secret = config.jwt_secret.clone();
        tokio::spawn(async move {
            // Bind (not `let _ =`) so the guard is moved into the task and lives until it ends.
            let _active = active;
            if let Err(e) = handle_connection(incoming, rooms, &jwt_secret).await {
                warn!("连接处理错误: {}", e);
            }
        });
    }

    error!("端点已关闭");
    Ok(())
}
/// Answers UDP latency probes with the relay's live connection counters.
///
/// Binds `0.0.0.0:(addr.port() + 10000)`. If that port would overflow `u16` or cannot be
/// bound, it falls back to an OS-assigned port and logs a warning. The bound address is
/// logged at startup either way. Every datagram starting with `FUNMC_PING` is answered
/// with `FUNMC_PONG <active> <total>`; other datagrams are ignored.
///
/// # Errors
/// Returns an error only if the fallback bind fails or the local address cannot be read;
/// receive errors are logged and the loop continues.
async fn run_ping_responder(addr: SocketAddr) -> Result<()> {
    const PING_PORT_OFFSET: u16 = 10000;
    const PING_MAGIC: &[u8] = b"FUNMC_PING";

    // `checked_add` avoids the u16 overflow (panic in debug, wrap in release) for high ports.
    let socket = match addr.port().checked_add(PING_PORT_OFFSET) {
        Some(port) => match UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port))).await {
            Ok(socket) => socket,
            Err(e) => {
                warn!(
                    "Ping responder could not bind port {}: {}; falling back to an ephemeral port",
                    port, e
                );
                UdpSocket::bind("0.0.0.0:0").await?
            }
        },
        None => {
            warn!(
                "Ping port {} + {} exceeds 65535; falling back to an ephemeral port",
                addr.port(),
                PING_PORT_OFFSET
            );
            UdpSocket::bind("0.0.0.0:0").await?
        }
    };
    info!("Ping responder listening on {}", socket.local_addr()?);

    let mut buf = [0u8; 64];
    loop {
        let (len, src) = match socket.recv_from(&mut buf).await {
            Ok(received) => received,
            Err(e) => {
                warn!("Ping recv error: {}", e);
                continue;
            }
        };
        if !buf[..len].starts_with(PING_MAGIC) {
            continue;
        }
        let response = format!(
            "FUNMC_PONG {} {}",
            ACTIVE_CONNECTIONS.load(Ordering::Relaxed),
            TOTAL_CONNECTIONS.load(Ordering::Relaxed)
        );
        // Best-effort reply: a lost pong just looks like packet loss to the prober.
        let _ = socket.send_to(response.as_bytes(), src).await;
    }
}
fn build_server_config() -> Result<ServerConfig> {
2026-02-25 20:35:01 +08:00
let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["fc.funmc.cn".into()])
.context("生成自签名证书失败")?;
let cert_der = CertificateDer::from(cert.der().to_vec());
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
let mut server_crypto = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert_der], key_der)
.context("TLS 配置失败")?;
server_crypto.alpn_protocols = vec![b"funmc".to_vec()];
let mut transport = TransportConfig::default();
transport.max_idle_timeout(Some(Duration::from_secs(60).try_into()?));
transport.keep_alive_interval(Some(Duration::from_secs(10)));
let quic_crypto = QuicServerConfig::try_from(Arc::new(server_crypto))
.context("rustls 配置转为 QUIC 失败")?;
let mut server_config = ServerConfig::with_crypto(Arc::new(quic_crypto));
server_config.transport_config(Arc::new(transport));
Ok(server_config)
}
/// Drives one client connection from QUIC handshake to cleanup.
///
/// Steps:
/// 1. Completes the handshake and authenticates the peer via `authenticate_peer`.
/// 2. Registers the connection in `rooms[room_id][user_id]`.
/// 3. Spawns a `relay_stream` task for every bidirectional stream the client opens, until
///    the connection closes.
/// 4. Unregisters the connection and removes the room once it is empty.
///
/// An authentication failure closes the connection with code 1 / `auth_failed` and returns
/// `Ok(())`.
///
/// # Errors
/// Only a failed QUIC handshake on `incoming` is returned as an error.
async fn handle_connection(
    incoming: quinn::Incoming,
    rooms: Arc<Rooms>,
    jwt_secret: &str,
) -> Result<()> {
    let conn = incoming.await.context("接受连接失败")?;
    let remote = conn.remote_address();
    info!("新连接: {}", remote);

    let (user_id, room_id) = match authenticate_peer(&conn, jwt_secret).await {
        Ok(result) => result,
        Err(e) => {
            warn!("[{}] 认证失败: {}", remote, e);
            conn.close(1u32.into(), b"auth_failed");
            return Ok(());
        }
    };
    info!("[{}] 用户 {} 加入房间 {}", remote, user_id, room_id);

    // Register in a single statement so the `entry` guard (a shard write lock on `rooms`) is
    // released immediately. Holding it for the whole connection would block every other room
    // in the same shard and deadlock the room removal during cleanup.
    rooms
        .entry(room_id)
        .or_insert_with(DashMap::new)
        .insert(user_id, conn.clone());

    loop {
        tokio::select! {
            stream = conn.accept_bi() => {
                match stream {
                    Ok((send, recv)) => {
                        // Snapshot the room's current members for this stream; the read guard
                        // from `get` is dropped at the end of the statement.
                        let peers = rooms
                            .get(&room_id)
                            .map(|room| room.value().clone())
                            .unwrap_or_else(DashMap::new);
                        let src_user = user_id;
                        tokio::spawn(async move {
                            if let Err(e) = relay_stream(send, recv, peers, src_user).await {
                                warn!("流中继错误: {}", e);
                            }
                        });
                    }
                    Err(quinn::ConnectionError::ApplicationClosed(_)) => {
                        info!("[{}] 用户 {} 主动断开", remote, user_id);
                        break;
                    }
                    Err(e) => {
                        warn!("[{}] 连接错误: {}", remote, e);
                        break;
                    }
                }
            }
            _ = conn.closed() => {
                info!("[{}] 连接已关闭", remote);
                break;
            }
        }
    }

    // Remove only our own registration. If the same user reconnected meanwhile, the entry now
    // holds the newer connection, which must stay reachable. The `get` guard is released at
    // the end of this `if let`, before `rooms.remove_if` takes the shard write lock.
    if let Some(room) = rooms.get(&room_id) {
        room.remove_if(&user_id, |_, registered| registered.stable_id() == conn.stable_id());
    }
    // `remove_if` re-checks emptiness under the shard write lock, so a peer joining
    // concurrently through `entry` cannot be discarded together with the room.
    if rooms.remove_if(&room_id, |_, peers| peers.is_empty()).is_some() {
        info!("房间 {} 已清空并移除", room_id);
    }
    Ok(())
}
async fn authenticate_peer(conn: &quinn::Connection, jwt_secret: &str) -> Result<(Uuid, Uuid)> {
let mut recv = conn
.accept_uni()
.await
.context("等待认证流超时")?;
let mut len_buf = [0u8; 4];
recv.read_exact(&mut len_buf).await.context("读取长度失败")?;
let len = u32::from_be_bytes(len_buf) as usize;
if len > 4096 {
anyhow::bail!("认证数据过大");
}
let mut buf = vec![0u8; len];
recv.read_exact(&mut buf).await.context("读取认证数据失败")?;
#[derive(serde::Deserialize)]
struct AuthHandshake {
token: String,
room_id: Uuid,
}
let handshake: AuthHandshake =
serde_json::from_slice(&buf).context("解析认证数据失败")?;
let user_id = validate_relay_token(&handshake.token, jwt_secret)?;
Ok((user_id, handshake.room_id))
}
/// Relays one client-opened bidirectional stream to other members of its room.
///
/// Wire format read from `src_recv`:
/// - a 17-byte header: byte 0 is the mode (`0` = broadcast, anything else = unicast), and
///   bytes 1..17 are the destination user id (used only for unicast);
/// - the payload, read to end-of-stream, at most 1 MiB.
///
/// Header and payload are forwarded unchanged on a new bidirectional stream:
/// - to the destination user for unicast; a unicast to a user not in `peers` is dropped
///   silently;
/// - to every peer except `source_user` for broadcast, best-effort, with per-peer failures
///   ignored.
///
/// `src_send` is finished without writing any data.
///
/// # Errors
/// Fails on any of:
/// - a short header;
/// - a payload over 1 MiB, or a read error on `src_recv`;
/// - for unicast, any failure to open, write or finish the destination stream.
async fn relay_stream(
    mut src_send: quinn::SendStream,
    mut src_recv: quinn::RecvStream,
    peers: DashMap<Uuid, quinn::Connection>,
    source_user: Uuid,
) -> Result<()> {
    const HEADER_LEN: usize = 17;
    const MAX_PAYLOAD: usize = 1024 * 1024;

    let mut header = [0u8; HEADER_LEN];
    src_recv.read_exact(&mut header).await?;
    let dest_user = match header[0] {
        0 => None,
        _ => Some(Uuid::from_slice(&header[1..])?),
    };

    // Propagate read errors (e.g. payload too large) instead of forwarding an empty body.
    let body = src_recv
        .read_to_end(MAX_PAYLOAD)
        .await
        .context("读取中继数据失败")?;
    let full_payload = [&header[..], &body[..]].concat();

    match dest_user {
        Some(target) => {
            // Clone the connection out so no map guard is held across the awaits below.
            let peer_conn = peers.get(&target).map(|entry| entry.value().clone());
            if let Some(peer_conn) = peer_conn {
                let (mut send, _recv) = peer_conn.open_bi().await?;
                send.write_all(&full_payload).await?;
                send.finish()?;
            }
        }
        None => {
            let targets: Vec<quinn::Connection> = peers
                .iter()
                .filter(|entry| *entry.key() != source_user)
                .map(|entry| entry.value().clone())
                .collect();
            for peer_conn in targets {
                if let Ok((mut send, _recv)) = peer_conn.open_bi().await {
                    let _ = send.write_all(&full_payload).await;
                    let _ = send.finish();
                }
            }
        }
    }

    src_send.finish()?;
    Ok(())
}