diff --git a/server/entry.ts b/server/entry.ts index 228f426..01e8a3c 100644 --- a/server/entry.ts +++ b/server/entry.ts @@ -1,5 +1,6 @@ import { constants, access } from 'node:fs/promises'; import { createServer } from 'node:http'; +import { WebSocketServer } from 'ws'; import { hp_getConfig, hp_loadConfig } from '~server/context/loader'; import { listener } from '~server/listener'; import log from '~server/utils/log'; @@ -21,18 +22,13 @@ await hp_loadConfig(); const server = createServer(listener); const context = hp_getConfig(); -const ws = initWebsocket(context.server.agent.authkey); -if (ws) { +if (context.server.agent.authkey.length > 0) { + const ws = new WebSocketServer({ server }); + initWebsocket(ws, context.server.agent.authkey); await hp_loadAgentCache( context.server.agent.ttl, context.server.agent.cache_path, ); - - server.on('upgrade', (req, socket, head) => { - ws.handleUpgrade(req, socket, head, (ws) => { - ws.emit('connection', ws, req); - }); - }); } server.listen(context.server.port, context.server.host, () => { diff --git a/server/ws/socket.ts b/server/ws/socket.ts index 11716cc..73766e0 100644 --- a/server/ws/socket.ts +++ b/server/ws/socket.ts @@ -1,12 +1,7 @@ import WebSocket, { WebSocketServer } from 'ws'; import log from '~server/utils/log'; -const server = new WebSocketServer({ noServer: true }); -export function initWebsocket(authKey: string) { - if (authKey.length === 0) { - return; - } - +export function initWebsocket(server: WebSocketServer, authKey: string) { log.info('SRVX', 'Starting a WebSocket server for agent connections'); server.on('connection', (ws, req) => { const tailnetID = req.headers['x-headplane-tailnet-id'];