diff --git a/agent/config/config.go b/agent/config/config.go index 6ee7d15..a0cde3c 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -1,8 +1,9 @@ package config import ( - _ "github.com/joho/godotenv/autoload" "os" + + _ "github.com/joho/godotenv/autoload" ) // Config represents the configuration for the agent. @@ -16,12 +17,12 @@ type Config struct { } const ( - DebugEnv = "HP_AGENT_DEBUG" - HostnameEnv = "HP_AGENT_HOSTNAME" - TSControlURLEnv = "HP_AGENT_TS_SERVER" - TSAuthKeyEnv = "HP_AGENT_TS_AUTHKEY" - HPControlURLEnv = "HP_AGENT_HP_SERVER" - HPAuthKeyEnv = "HP_AGENT_HP_AUTHKEY" + DebugEnv = "HEADPLANE_AGENT_DEBUG" + HostnameEnv = "HEADPLANE_AGENT_HOSTNAME" + TSControlURLEnv = "HEADPLANE_AGENT_TS_SERVER" + TSAuthKeyEnv = "HEADPLANE_AGENT_TS_AUTHKEY" + HPControlURLEnv = "HEADPLANE_AGENT_HP_SERVER" + HPAuthKeyEnv = "HEADPLANE_AGENT_HP_AUTHKEY" ) // Load reads the agent configuration from environment variables. diff --git a/agent/hpagent/websocket.go b/agent/hpagent/websocket.go index 06c9056..fc544ca 100644 --- a/agent/hpagent/websocket.go +++ b/agent/hpagent/websocket.go @@ -2,11 +2,12 @@ package hpagent import ( "fmt" - "github.com/gorilla/websocket" - "github.com/tale/headplane/agent/tsnet" "log" "net/http" "net/url" + + "github.com/gorilla/websocket" + "github.com/tale/headplane/agent/tsnet" ) type Socket struct { @@ -23,7 +24,7 @@ func NewSocket(agent *tsnet.TSAgent, controlURL, authKey string, debug bool) (*S } headers := http.Header{} - headers.Add("X-Headplane-TS-Node-ID", agent.ID) + headers.Add("X-Headplane-Tailnet-ID", agent.ID) auth := fmt.Sprintf("Bearer %s", authKey) headers.Add("Authorization", auth) diff --git a/app/components/Header.tsx b/app/components/Header.tsx index 547bd8c..01a2f2c 100644 --- a/app/components/Header.tsx +++ b/app/components/Header.tsx @@ -91,8 +91,19 @@ export default function Header(data: Props) { {data.user ? ( - - + + {data.user.picture ? ( + {data.user.name} + ) : ( + + )} { diff --git a/app/routes.ts b/app/routes.ts index 51a6363..7fcf112 100644 --- a/app/routes.ts +++ b/app/routes.ts @@ -11,6 +11,9 @@ export default [ route('/oidc/callback', 'routes/auth/oidc-callback.ts'), route('/oidc/start', 'routes/auth/oidc-start.ts'), + // API + route('/api/agent', 'routes/api/agent.ts'), + // All the main logged-in dashboard routes // Double nested to separate error propagations layout('layouts/shell.tsx', [ diff --git a/app/routes/api/agent.ts b/app/routes/api/agent.ts new file mode 100644 index 0000000..34810a3 --- /dev/null +++ b/app/routes/api/agent.ts @@ -0,0 +1,39 @@ +import { LoaderFunctionArgs } from 'react-router'; +import type { AppContext } from '~server/context/app'; + +export async function loader({ + request, + context, +}: LoaderFunctionArgs) { + if (!context?.agentData) { + return new Response(JSON.stringify({ error: 'Agent data unavailable' }), { + status: 400, + headers: { + 'Content-Type': 'application/json', + }, + }); + } + + const qp = new URLSearchParams(request.url.split('?')[1]); + const nodeIds = qp.get('node_ids')?.split(','); + if (!nodeIds) { + return new Response(JSON.stringify({ error: 'No node IDs provided' }), { + status: 400, + headers: { + 'Content-Type': 'application/json', + }, + }); + } + + const entries = context.agentData.toJSON(); + const missing = nodeIds.filter((nodeID) => !entries[nodeID]); + if (missing.length > 0) { + await context.hp_agentRequest(missing); + } + + return new Response(JSON.stringify(context.agentData), { + headers: { + 'Content-Type': 'application/json', + }, + }); +} diff --git a/app/routes/auth/login.tsx b/app/routes/auth/login.tsx index 107fd48..d7fa760 100644 --- a/app/routes/auth/login.tsx +++ b/app/routes/auth/login.tsx @@ -11,6 +11,7 @@ import Input from '~/components/Input'; import type { Key } from '~/types'; import { pull } from '~/utils/headscale'; import { noContext } from '~/utils/log'; +import { oidcEnabled } from '~/utils/oidc'; import { commitSession, getSession } from '~/utils/sessions.server'; import type { AppContext } from '~server/context/app'; @@ -33,12 +34,12 @@ export async function loader({ // Only set if OIDC is properly enabled anyways const ctx = context.context; - if (ctx.oidc?.disable_api_key_login) { + if (oidcEnabled() && ctx.oidc?.disable_api_key_login) { return redirect('/oidc/start'); } return { - oidc: ctx.oidc?.issuer, + oidc: oidcEnabled(), apiKey: !ctx.oidc?.disable_api_key_login, }; } @@ -132,7 +133,7 @@ export default function Page() { ) : undefined} - {data.oidc ? ( + {data.oidc === true ? (
{!data.apiKey ? ( diff --git a/app/routes/machines/components/machine.tsx b/app/routes/machines/components/machine.tsx index db955f7..e78932a 100644 --- a/app/routes/machines/components/machine.tsx +++ b/app/routes/machines/components/machine.tsx @@ -15,6 +15,7 @@ interface Props { machine: Machine; routes: Route[]; users: User[]; + isAgent?: boolean; magic?: string; stats?: HostInfo; } @@ -22,8 +23,9 @@ interface Props { export default function MachineRow({ machine, routes, - magic, users, + isAgent, + magic, stats, }: Props) { const expired = @@ -79,6 +81,10 @@ export default function MachineRow({ tags.unshift('Subnets'); } + if (isAgent) { + tags.unshift('Headplane Agent'); + } + const ipOptions = useMemo(() => { if (magic) { return [...machine.ipAddresses, `${machine.givenName}.${prefix}`]; @@ -146,24 +152,18 @@ export default function MachineRow({
- {/** - {stats !== undefined ? ( - <> -

- {hinfo.getTSVersion(stats)} -

+ {stats !== undefined ? ( + <> +

{hinfo.getTSVersion(stats)}

{hinfo.getOSInfo(stats)}

- - ) : ( -

- Unknown -

- )} + + ) : ( +

Unknown

+ )} - **/} ) { const session = await getSession(request.headers.get('Cookie')); if (!params.id) { throw new Error('No machine ID provided'); @@ -44,6 +49,7 @@ export async function loader({ request, params }: LoaderFunctionArgs) { routes: routes.routes.filter((route) => route.node.id === params.id), users: users.users, magic, + agent: context?.agents.includes(machine.node.id), }; } @@ -52,8 +58,10 @@ export async function action({ request }: ActionFunctionArgs) { } export default function Page() { - const { machine, magic, routes, users } = useLoaderData(); + const { machine, magic, routes, users, agent } = + useLoaderData(); const [showRouting, setShowRouting] = useState(false); + console.log(machine.expiry); const expired = machine.expiry === '0001-01-01 00:00:00' || @@ -68,6 +76,10 @@ export default function Page() { tags.unshift('Expired'); } + if (agent) { + tags.unshift('Headplane Agent'); + } + // This is much easier with Object.groupBy but it's too new for us const { exit, subnet, subnetApproved } = routes.reduce<{ exit: Route[]; @@ -148,16 +160,18 @@ export default function Page() { {machine.user.name} -
-

- Status -

-
- {tags.map((tag) => ( - - ))} + {tags.length > 0 ? ( +
+

+ Status +

+
+ {tags.map((tag) => ( + + ))} +
-
+ ) : undefined}

Subnets & Routing

node.nodeKey)); const ctx = context.context; const { mode, config } = hs_getConfig(); - let magic: string | undefined; if (mode !== 'no') { @@ -53,9 +49,9 @@ export async function loader({ routes: routes.routes, users: users.users, magic, - stats, server: ctx.headscale.url, publicServer: ctx.headscale.public_url, + agents: context.agents, }; } @@ -65,6 +61,7 @@ export async function action({ request }: ActionFunctionArgs) { export default function Page() { const data = useLoaderData(); + const { data: stats } = useAgent(data.nodes.map((node) => node.nodeKey)); return ( <> @@ -108,7 +105,7 @@ export default function Page() { ) : undefined} - {/**Version**/} + Version Last Seen @@ -127,7 +124,8 @@ export default function Page() { )} users={data.users} magic={data.magic} - stats={data.stats?.[machine.nodeKey]} + stats={stats?.[machine.nodeKey]} + isAgent={data.agents.includes(machine.id)} /> ))} diff --git a/app/utils/config/parser.ts b/app/utils/config/parser.ts index 7363987..4e0046b 100644 --- a/app/utils/config/parser.ts +++ b/app/utils/config/parser.ts @@ -105,8 +105,8 @@ const headscaleConfig = type({ magic_dns: goBool.default(true), base_domain: 'string = "headscale.net"', nameservers: type({ - 'global?': 'string[]', - 'split': type('Record').default(() => ({})), + global: type('string[]').default(() => []), + split: type('Record').default(() => ({})), }).default(() => ({ global: [], split: {} })), search_domains: type('string[]').default(() => []), extra_records: type({ diff --git a/app/utils/oidc.ts b/app/utils/oidc.ts index 7fce546..8515b5b 100644 --- a/app/utils/oidc.ts +++ b/app/utils/oidc.ts @@ -1,6 +1,7 @@ +import { readFile } from 'node:fs/promises'; import * as client from 'openid-client'; -import log from '~/utils/log'; import type { AppContext } from '~server/context/app'; +import log from '~server/utils/log'; type OidcConfig = NonNullable; declare global { @@ -35,6 +36,57 @@ export function getRedirectUri(req: Request) { return url.href; } +let oidcSecret: string | undefined = undefined; +export function getOidcSecret() { + return oidcSecret; +} + +async function resolveClientSecret(oidc: OidcConfig) { + if (!oidc.client_secret && !oidc.client_secret_path) { + return; + } + + if (oidc.client_secret_path) { + // We need to interpolate environment variables into the path + // Path formatting can be like ${ENV_NAME}/path/to/secret + let path = oidc.client_secret_path; + const matches = path.match(/\${(.*?)}/g); + + if (matches) { + for (const match of matches) { + const env = match.slice(2, -1); + const value = process.env[env]; + if (!value) { + log.error('CFGX', 'Environment variable %s is not set', env); + return; + } + + log.debug('CFGX', 'Interpolating %s with %s', match, value); + path = path.replace(match, value); + } + } + + try { + log.debug('CFGX', 'Reading client secret from %s', path); + const secret = await readFile(path, 'utf-8'); + if (secret.trim().length === 0) { + log.error('CFGX', 'Empty OIDC client secret'); + return; + } + + oidcSecret = secret; + } catch (error) { + log.error('CFGX', 'Failed to read client secret from %s', path); + log.error('CFGX', 'Error: %s', error); + log.debug('CFGX', 'Error details: %o', error); + } + } + + if (oidc.client_secret) { + oidcSecret = oidc.client_secret; + } +} + function clientAuthMethod( method: string, ): (secret: string) => client.ClientAuth { @@ -55,7 +107,7 @@ export async function beginAuthFlow(oidc: OidcConfig, redirect_uri: string) { new URL(oidc.issuer), oidc.client_id, oidc.client_secret, - clientAuthMethod(oidc.token_endpoint_auth_method)(oidc.client_secret), + clientAuthMethod(oidc.token_endpoint_auth_method)(__oidc_context.secret), ); const codeVerifier = client.randomPKCECodeVerifier(); @@ -97,7 +149,7 @@ export async function finishAuthFlow(oidc: OidcConfig, options: FlowOptions) { new URL(oidc.issuer), oidc.client_id, oidc.client_secret, - clientAuthMethod(oidc.token_endpoint_auth_method)(oidc.client_secret), + clientAuthMethod(oidc.token_endpoint_auth_method)(__oidc_context.secret), ); let subject: string; @@ -126,15 +178,41 @@ export async function finishAuthFlow(oidc: OidcConfig, options: FlowOptions) { ); return { - subject: claims.sub, - name: claims.name ? String(claims.name) : 'Anonymous', - email: claims.email ? String(claims.email) : undefined, - username: claims.preferred_username - ? String(claims.preferred_username) - : undefined, + subject: user.sub, + name: getName(user, claims), + email: user.email ?? claims.email?.toString(), + username: user.preferred_username ?? claims.preferred_username?.toString(), + picture: user.picture, }; } +function getName(user: client.UserInfoResponse, claims: client.IDToken) { + if (user.name) { + return user.name; + } + + if (claims.name && typeof claims.name === 'string') { + return claims.name; + } + + if (user.given_name && user.family_name) { + return `${user.given_name} ${user.family_name}`; + } + + if (user.preferred_username) { + return user.preferred_username; + } + + if ( + claims.preferred_username && + typeof claims.preferred_username === 'string' + ) { + return claims.preferred_username; + } + + return 'Anonymous'; +} + export function formatError(error: unknown) { if (error instanceof client.ResponseBodyError) { return { @@ -177,13 +255,27 @@ export function formatError(error: unknown) { }; } +export function oidcEnabled() { + return __oidc_context.valid; +} + export async function testOidc(oidc: OidcConfig) { + await resolveClientSecret(oidc); + if (!oidcSecret) { + log.debug( + 'OIDC', + 'Cannot validate OIDC configuration without a client secret', + ); + return false; + } + log.debug('OIDC', 'Discovering OIDC configuration from %s', oidc.issuer); + const secret = await resolveClientSecret(oidc); const config = await client.discovery( new URL(oidc.issuer), oidc.client_id, oidc.client_secret, - clientAuthMethod(oidc.token_endpoint_auth_method)(oidc.client_secret), + clientAuthMethod(oidc.token_endpoint_auth_method)(oidcSecret), ); const meta = config.serverMetadata(); @@ -214,13 +306,9 @@ export async function testOidc(oidc: OidcConfig) { 'OIDC server does not support %s', oidc.token_endpoint_auth_method, ); + return false; } - } else { - log.warn( - 'OIDC', - 'OIDC server does not advertise token_endpoint_auth_methods_supported', - ); } log.debug('OIDC', 'OIDC configuration is valid'); diff --git a/app/utils/sessions.server.ts b/app/utils/sessions.server.ts index cf6ccf7..dac2f29 100644 --- a/app/utils/sessions.server.ts +++ b/app/utils/sessions.server.ts @@ -12,6 +12,7 @@ export type SessionData = { name: string; email?: string; username?: string; + picture?: string; }; }; diff --git a/app/utils/useAgent.ts b/app/utils/useAgent.ts new file mode 100644 index 0000000..808c609 --- /dev/null +++ b/app/utils/useAgent.ts @@ -0,0 +1,32 @@ +import { useEffect, useMemo, useRef } from 'react'; +import { useFetcher } from 'react-router'; +import { HostInfo } from '~/types'; + +export default function useAgent(nodeIds: string[], interval = 3000) { + const fetcher = useFetcher>(); + const qp = useMemo( + () => new URLSearchParams({ node_ids: nodeIds.join(',') }), + [nodeIds], + ); + + const idRef = useRef([]); + useEffect(() => { + if (idRef.current.join(',') !== nodeIds.join(',')) { + fetcher.load(`/api/agent?${qp.toString()}`); + idRef.current = nodeIds; + } + + const intervalID = setInterval(() => { + fetcher.load(`/api/agent?${qp.toString()}`); + }, interval); + + return () => { + clearInterval(intervalID); + }; + }, [interval, qp]); + + return { + data: fetcher.data, + isLoading: fetcher.state === 'loading', + }; +} diff --git a/config.example.yaml b/config.example.yaml index 4aeff9c..df4ac7a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -73,7 +73,15 @@ integration: oidc: issuer: "https://accounts.google.com" client_id: "your-client-id" + + # The client secret for the OIDC client + # Either this or `client_secret_path` must be set for OIDC to work client_secret: "" + # You can alternatively set `client_secret_path` to read the secret from disk. + # The path specified can resolve environment variables, making integration + # with systemd's `LoadCredential` straightforward: + # client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret" + disable_api_key_login: false token_endpoint_auth_method: "client_secret_post" diff --git a/docs/Headplane-Agent.md b/docs/Headplane-Agent.md index d75075d..828ca38 100644 --- a/docs/Headplane-Agent.md +++ b/docs/Headplane-Agent.md @@ -1,5 +1,8 @@ # Headplane Agent +> This is currently not available in Headplane. +> It is incomplete and will land within the next few releases. + The Headplane agent is a lightweight service that runs alongside the Headscale server. It's used to interface with devices on your network locally, unlocking the following: @@ -19,17 +22,12 @@ Agent binaries are available on the [releases](https://github.com/tale/headplane The Docker image is available through the `ghcr.io/tale/headplane-agent` tag. The agent requires the following environment variables to be set: -- **`HP_AGENT_HOSTNAME`**: A hostname you want to use for the agent. -- **`HP_AGENT_TS_SERVER`**: The URL to your Headscale instance. -- **`HP_AGENT_TS_AUTHKEY`**: An authorization key to authenticate with Headscale (see below). -- **`HP_AGENT_HP_SERVER`**: The URL to your Headplane instance. -- **`HP_AGENT_HP_AUTHKEY`**: The generated auth key to authenticate with Headplane. +- **`HEADPLANE_AGENT_DEBUG`**: Enable debug logging if `true`. +- **`HEADPLANE_AGENT_HOSTNAME`**: A hostname you want to use for the agent. +- **`HEADPLANE_AGENT_TS_SERVER`**: The URL to your Headscale instance. +- **`HEADPLANE_AGENT_TS_AUTHKEY`**: An authorization key to authenticate with Headscale (see below). +- **`HEADPLANE_AGENT_HP_SERVER`**: The URL to your Headplane instance, including the subpath (eg. `https://headplane.example.com/admin`). +- **`HEADPLANE_AGENT_HP_AUTHKEY`**: The generated auth key to authenticate with Headplane. If you already have Headplane setup, you can generate all of these values within the Headplane UI. Navigate to the `Settings` page and click `Agent` to get started. - -HP_AGENT_HOSTNAME=headplane-agent -HP_AGENT_TS_SERVER=http://localhost:8080 -#HP_AGENT_AUTH_KEY=3e0cd749021e5984267cde4b0a5a2ac32c1859e56f7911aa -HP_AGENT_TS_AUTHKEY=a4dab065c735cb4eae4f12804cf7e111206f9c7c9247c629 -HP_AGENT_HP_SERVER=http://localhost:3000/admin diff --git a/package.json b/package.json index dc56156..5fcd837 100644 --- a/package.json +++ b/package.json @@ -2,6 +2,7 @@ "name": "headplane", "private": true, "sideEffects": false, + "version": "0.5.3", "type": "module", "scripts": { "build": "react-router build && vite build -c server/vite.config.ts", diff --git a/server/context/app.ts b/server/context/app.ts index c06e79b..729886f 100644 --- a/server/context/app.ts +++ b/server/context/app.ts @@ -1,12 +1,22 @@ +import type { HostInfo } from '~/types'; +import { TimedCache } from '~server/ws/cache'; +import { hp_agentRequest, hp_getAgentCache } from '~server/ws/data'; +import { hp_getAgents } from '~server/ws/socket'; import { hp_getConfig } from './loader'; -import { HeadplaneConfig } from './parser'; +import type { HeadplaneConfig } from './parser'; export interface AppContext { context: HeadplaneConfig; + hp_agentRequest: typeof hp_agentRequest; + agents: string[]; + agentData?: TimedCache; } -export default function appContext() { +export default function appContext(): AppContext { return { context: hp_getConfig(), + hp_agentRequest, + agents: [...hp_getAgents().keys()], + agentData: hp_getAgentCache(), }; } diff --git a/server/context/globals.ts b/server/context/globals.ts new file mode 100644 index 0000000..9ca8fc0 --- /dev/null +++ b/server/context/globals.ts @@ -0,0 +1,21 @@ +import { HeadplaneConfig } from './parser'; + +declare global { + const __cookie_context: { + cookie_secret: string; + cookie_secure: boolean; + }; + + const __hs_context: { + url: string; + config_path?: string; + config_strict?: boolean; + }; + + const __oidc_context: { + valid: boolean; + secret: string; + }; + + let __integration_context: HeadplaneConfig['integration']; +} diff --git a/server/context/loader.ts b/server/context/loader.ts index 91fcadf..f97b703 100644 --- a/server/context/loader.ts +++ b/server/context/loader.ts @@ -3,7 +3,7 @@ import { env } from 'node:process'; import { type } from 'arktype'; import dotenv from 'dotenv'; import { parseDocument } from 'yaml'; -import { testOidc } from '~/utils/oidc'; +import { getOidcSecret, testOidc } from '~/utils/oidc'; import log, { hpServer_loadLogger } from '~server/utils/log'; import mutex from '~server/utils/mutex'; import { HeadplaneConfig, coalesceConfig, validateConfig } from './parser'; @@ -20,6 +20,11 @@ declare namespace globalThis { config_strict?: boolean; }; + let __oidc_context: { + valid: boolean; + secret: string; + }; + let __integration_context: HeadplaneConfig['integration']; } @@ -113,8 +118,27 @@ export async function hp_loadConfig() { process.exit(1); } - if (config.oidc?.strict_validation) { - testOidc(config.oidc); + // OIDC Related Checks + if (config.oidc) { + if (!config.oidc.client_secret && !config.oidc.client_secret_path) { + log.error('CFGX', 'OIDC configuration is missing a secret, disabling'); + log.error( + 'CFGX', + 'Please specify either `oidc.client_secret` or `oidc.client_secret_path`', + ); + } + + if (config.oidc?.strict_validation) { + const result = await testOidc(config.oidc); + if (!result) { + log.error('CFGX', 'OIDC configuration failed validation, disabling'); + } + + globalThis.__oidc_context = { + valid: result, + secret: getOidcSecret() ?? '', + }; + } } globalThis.__cookie_context = { diff --git a/server/context/parser.ts b/server/context/parser.ts index da9d246..b324990 100644 --- a/server/context/parser.ts +++ b/server/context/parser.ts @@ -8,12 +8,24 @@ const serverConfig = type({ port: type('string | number.integer').pipe((v) => Number(v)), cookie_secret: '32 <= string <= 32', cookie_secure: stringToBool, + agent: type({ + authkey: 'string', + ttl: 'number.integer = 180000', // Default to 3 minutes + cache_path: 'string = "/var/lib/headplane/agent_cache.json"', + }) + .onDeepUndeclaredKey('reject') + .default(() => ({ + authkey: '', + ttl: 180000, + cache_path: '/var/lib/headplane/agent_cache.json', + })), }); const oidcConfig = type({ issuer: 'string.url', client_id: 'string', - client_secret: 'string', + client_secret: 'string?', + client_secret_path: 'string?', token_endpoint_auth_method: '"client_secret_basic" | "client_secret_post" | "client_secret_jwt"', redirect_uri: 'string.url?', diff --git a/server/entry.ts b/server/entry.ts index baac614..01e8a3c 100644 --- a/server/entry.ts +++ b/server/entry.ts @@ -1,9 +1,11 @@ -// import { initWebsocket } from '~server/ws'; import { constants, access } from 'node:fs/promises'; import { createServer } from 'node:http'; +import { WebSocketServer } from 'ws'; import { hp_getConfig, hp_loadConfig } from '~server/context/loader'; import { listener } from '~server/listener'; import log from '~server/utils/log'; +import { hp_loadAgentCache } from '~server/ws/data'; +import { initWebsocket } from '~server/ws/socket'; log.info('SRVX', 'Running Node.js %s', process.versions.node); @@ -19,16 +21,16 @@ try { await hp_loadConfig(); const server = createServer(listener); -// const ws = initWebsocket(); -// if (ws) { -// server.on('upgrade', (req, socket, head) => { -// ws.handleUpgrade(req, socket, head, (ws) => { -// ws.emit('connection', ws, req); -// }); -// }); -// } - const context = hp_getConfig(); +if (context.server.agent.authkey.length > 0) { + const ws = new WebSocketServer({ server }); + initWebsocket(ws, context.server.agent.authkey); + await hp_loadAgentCache( + context.server.agent.ttl, + context.server.agent.cache_path, + ); +} + server.listen(context.server.port, context.server.host, () => { log.info( 'SRVX', diff --git a/server/utils/ws.ts b/server/utils/ws.ts deleted file mode 100644 index b01f71c..0000000 --- a/server/utils/ws.ts +++ /dev/null @@ -1,60 +0,0 @@ -import WebSocket, { WebSocketServer } from 'ws'; -import log from '~server/utils/log'; - -const server = new WebSocketServer({ noServer: true }); -export function initWebsocket() { - // TODO: Finish this and make public - return; - - const key = process.env.LOCAL_AGENT_AUTHKEY; - if (!key) { - return; - } - - log.info('CACH', 'Initializing agent WebSocket'); - server.on('connection', (ws, req) => { - // biome-ignore lint: this file is not USED - const auth = req.headers['authorization']; - if (auth !== `Bearer ${key}`) { - log.warn('CACH', 'Invalid agent WebSocket connection'); - ws.close(1008, 'ERR_INVALID_AUTH'); - return; - } - - const nodeID = req.headers['x-headplane-ts-node-id']; - if (!nodeID) { - log.warn('CACH', 'Invalid agent WebSocket connection'); - ws.close(1008, 'ERR_INVALID_NODE_ID'); - return; - } - - const pinger = setInterval(() => { - if (ws.readyState !== WebSocket.OPEN) { - clearInterval(pinger); - return; - } - - ws.ping(); - }, 30000); - - ws.on('close', () => { - clearInterval(pinger); - }); - - ws.on('error', (error) => { - clearInterval(pinger); - log.error('CACH', 'Closing agent WebSocket connection'); - log.error('CACH', 'Agent WebSocket error: %s', error); - ws.close(1011, 'ERR_INTERNAL_ERROR'); - }); - }); - - return server; -} - -export function appContext() { - return { - ws: server, - wsAuthKey: process.env.LOCAL_AGENT_AUTHKEY, - }; -} diff --git a/server/ws/cache.ts b/server/ws/cache.ts new file mode 100644 index 0000000..148018a --- /dev/null +++ b/server/ws/cache.ts @@ -0,0 +1,126 @@ +import { createHash } from 'node:crypto'; +import { readFile, writeFile } from 'node:fs/promises'; +import { type } from 'arktype'; +import log from '~server/utils/log'; +import mutex from '~server/utils/mutex'; + +const diskSchema = type({ + key: 'string', + value: 'unknown', + expires: 'number?', +}).array(); + +// A persistent HashMap with a TTL for each key +export class TimedCache { + private _cache = new Map(); + private _timings = new Map(); + + // Default TTL is 1 minute + private defaultTTL: number; + private filePath: string; + private writeLock = mutex(); + + // Last flush ID is essentially a hash of the flush contents + // Prevents unnecessary flushing if nothing has changed + private lastFlushId = ''; + + constructor(defaultTTL: number, filePath: string) { + this.defaultTTL = defaultTTL; + this.filePath = filePath; + + // Load the cache from disk and then queue flushes every 10 seconds + this.load().then(() => { + setInterval(() => this.flush(), 10000); + }); + } + + set(key: string, value: V, ttl: number = this.defaultTTL) { + this._cache.set(key, value); + this._timings.set(key, Date.now() + ttl); + } + + get(key: string) { + const value = this._cache.get(key); + if (!value) { + return; + } + + const expires = this._timings.get(key); + if (!expires || expires < Date.now()) { + this._cache.delete(key); + this._timings.delete(key); + return; + } + + return value; + } + + // Map into a Record without any TTLs + toJSON() { + const result: Record = {}; + for (const [key, value] of this._cache.entries()) { + result[key] = value; + } + + return result; + } + + // WARNING: This function expects that this.filePath is NOT ENOENT + private async load() { + const data = await readFile(this.filePath, 'utf-8'); + const cache = () => { + try { + return JSON.parse(data); + } catch (e) { + return undefined; + } + }; + + const diskData = cache(); + if (diskData === undefined) { + log.error('CACH', 'Failed to load cache at %s', this.filePath); + return; + } + + const cacheData = diskSchema(diskData); + if (cacheData instanceof type.errors) { + log.error('CACH', 'Failed to load cache at %s', this.filePath); + log.debug('CACHE', 'Error details: %s', cacheData.toString()); + + // Skip loading the cache (it should be overwritten soon) + return; + } + + for (const { key, value, expires } of diskData) { + this._cache.set(key, value); + this._timings.set(key, expires); + } + + log.info('CACH', 'Loaded cache from %s', this.filePath); + } + + private async flush() { + this.writeLock.acquire(); + const data = Array.from(this._cache.entries()).map(([key, value]) => { + return { key, value, expires: this._timings.get(key) }; + }); + + if (data.length === 0) { + this.writeLock.release(); + return; + } + + // Calculate the hash of the data + const dumpData = JSON.stringify(data); + const sha = createHash('sha256').update(dumpData).digest('hex'); + if (sha === this.lastFlushId) { + this.writeLock.release(); + return; + } + + await writeFile(this.filePath, dumpData, 'utf-8'); + this.lastFlushId = sha; + this.writeLock.release(); + log.debug('CACH', 'Flushed cache to %s', this.filePath); + } +} diff --git a/server/ws/data.ts b/server/ws/data.ts new file mode 100644 index 0000000..38664fa --- /dev/null +++ b/server/ws/data.ts @@ -0,0 +1,61 @@ +import { open } from 'node:fs/promises'; +import type { HostInfo } from '~/types'; +import log from '~server/utils/log'; +import { TimedCache } from './cache'; +import { hp_getAgents } from './socket'; + +let cache: TimedCache | undefined; +export async function hp_loadAgentCache(defaultTTL: number, filepath: string) { + log.debug('CACH', `Loading agent cache from ${filepath}`); + + try { + const handle = await open(filepath, 'w'); + log.info('CACH', `Using agent cache file at ${filepath}`); + await handle.close(); + } catch (e) { + log.info('CACH', `Agent cache file not found at ${filepath}`); + return; + } + + cache = new TimedCache(defaultTTL, filepath); +} + +export function hp_getAgentCache() { + return cache; +} + +export async function hp_agentRequest(nodeList: string[]) { + // Request to all connected agents (we can have multiple) + // Luckily we can parse all the data at once through message parsing + // and then overlapping cache entries will be overwritten by time + const agents = hp_getAgents(); + + // Deduplicate the list of nodes + const NodeIDs = [...new Set(nodeList)]; + NodeIDs.map((node) => { + log.debug('CACH', 'Requesting agent data for', node); + }); + + // Await so that data loads on first request without racing + // Since we do agent.once() we NEED to wait for it to finish + await Promise.allSettled( + [...agents].map(async ([id, agent]) => { + agent.send(JSON.stringify({ NodeIDs })); + await new Promise((resolve) => { + // Just as a safety measure, we set a maximum timeout of 3 seconds + setTimeout(() => resolve(), 3000); + + agent.once('message', (data) => { + const parsed = JSON.parse(data.toString()); + log.debug('CACH', 'Received agent data from %s', id); + for (const [node, info] of Object.entries(parsed)) { + cache?.set(node, info); + log.debug('CACH', 'Cached %s', node); + } + + resolve(); + }); + }); + }), + ); +} diff --git a/server/ws/socket.ts b/server/ws/socket.ts new file mode 100644 index 0000000..73766e0 --- /dev/null +++ b/server/ws/socket.ts @@ -0,0 +1,57 @@ +import WebSocket, { WebSocketServer } from 'ws'; +import log from '~server/utils/log'; + +export function initWebsocket(server: WebSocketServer, authKey: string) { + log.info('SRVX', 'Starting a WebSocket server for agent connections'); + server.on('connection', (ws, req) => { + const tailnetID = req.headers['x-headplane-tailnet-id']; + if (!tailnetID || typeof tailnetID !== 'string') { + log.warn( + 'SRVX', + 'Rejecting an agent WebSocket connection without a tailnet ID', + ); + ws.close(1008, 'ERR_INVALID_TAILNET_ID'); + return; + } + + if (req.headers.authorization !== `Bearer ${authKey}`) { + log.warn('SRVX', 'Rejecting an unauthorized WebSocket connection'); + if (req.socket.remoteAddress) { + log.warn('SRVX', 'Agent source IP: %s', req.socket.remoteAddress); + } + + ws.close(1008, 'ERR_UNAUTHORIZED'); + return; + } + + agents.set(tailnetID, ws); + const pinger = setInterval(() => { + if (ws.readyState !== WebSocket.OPEN) { + clearInterval(pinger); + return; + } + + ws.ping(); + }, 30000); + + ws.on('close', () => { + clearInterval(pinger); + agents.delete(tailnetID); + }); + + ws.on('error', (error) => { + clearInterval(pinger); + log.error('SRVX', 'Agent WebSocket error: %s', error); + log.debug('SRVX', 'Error details: %o', error); + log.error('SRVX', 'Closing agent WebSocket connection'); + ws.close(1011, 'ERR_INTERNAL_ERROR'); + }); + }); + + return server; +} + +const agents = new Map(); +export function hp_getAgents() { + return agents; +} diff --git a/vite.config.ts b/vite.config.ts index 5b36858..1ba3986 100644 --- a/vite.config.ts +++ b/vite.config.ts @@ -1,19 +1,21 @@ +import { readFile } from 'node:fs/promises'; import { reactRouter } from '@react-router/dev/vite'; +import autoprefixer from 'autoprefixer'; +import tailwindcss from 'tailwindcss'; import { defineConfig } from 'vite'; import babel from 'vite-plugin-babel'; import tsconfigPaths from 'vite-tsconfig-paths'; -import fs from 'node:fs'; -import tailwindcss from 'tailwindcss'; -import autoprefixer from 'autoprefixer'; const prefix = process.env.__INTERNAL_PREFIX || '/admin'; if (prefix.endsWith('/')) { throw new Error('Prefix must not end with a slash'); } -const version = fs.readFileSync("version", "utf8"); +// Load the version via package.json +const pkg = await readFile('package.json', 'utf-8'); +const { version } = JSON.parse(pkg); if (!version) { - throw new Error('Unable to read ./version'); + throw new Error('Unable to read version from package.json'); } export default defineConfig({