181 lines
6.6 KiB
TypeScript
181 lines
6.6 KiB
TypeScript
import type {
|
|
JsonValue,
|
|
ProviderToCoreMessage,
|
|
} from "../../shared/src/index";
|
|
import { isProviderToCoreMessage as checkProviderMessage } from "../../shared/src/index";
|
|
import { ctx, sql, logger, config } from "./context";
|
|
import type { ProviderSocket, WsData } from "./types";
|
|
import { errorToJson, wsSendJson } from "./http";
|
|
import { recordEvent, upsertProviderNode, updateProviderHeartbeat, upsertDockerStatus, upsertSystemStatus } from "./db";
|
|
import { notifyTaskTerminal } from "./task-dispatcher";
|
|
import { forwardSshProviderMessage } from "./ssh-bridge";
|
|
import { handleEgressTcpOpen, handleEgressTcpData, handleEgressTcpClose } from "./egress-tcp";
|
|
|
|
/**
 * Reports whether a provider-reported task status is final.
 *
 * Only `"succeeded"` and `"failed"` count as terminal. This is the same pair
 * the task-status UPDATE in `handleProviderMessage` treats as terminal.
 */
function isTerminalTaskStatus(status: string): boolean {
  switch (status) {
    case "succeeded":
    case "failed":
      return true;
    default:
      return false;
  }
}
|
|
|
|
/**
 * Decodes one raw provider WebSocket frame and validates it against the shared
 * provider→core message schema.
 *
 * Buffers are decoded as UTF-8 before parsing.
 *
 * @param raw - The frame exactly as received from the socket.
 * @returns The validated provider message.
 * @throws SyntaxError (from `JSON.parse`) when the payload is not valid JSON.
 * @throws Error when the JSON is not a recognised provider message; the error
 *   text includes at most the first 200 characters of the payload.
 */
export function parseMessage(raw: string | Buffer): ProviderToCoreMessage {
  let payload: string;
  if (typeof raw === "string") {
    payload = raw;
  } else {
    payload = raw.toString("utf8");
  }

  const candidate: unknown = JSON.parse(payload);
  if (checkProviderMessage(candidate)) {
    return candidate;
  }
  throw new Error(`Unsupported provider message: ${payload.slice(0, 200)}`);
}
|
|
|
|
/**
 * Takes a provider out of service.
 *
 * The provider's socket is always removed from the in-memory
 * `ctx.activeProviders` registry. When the database is ready, two more steps
 * follow:
 * - the provider's `unidesk_nodes` row is set to `offline`;
 * - a `provider_offline` event is recorded.
 *
 * NOTE(review): the registry entry is removed by provider id alone, and
 * `handleProviderMessage` re-keys the registry by id on every frame. If a
 * provider has already reconnected on a new socket, this call would also drop
 * the new socket. Confirm that callers account for that.
 *
 * @param providerId - Id of the provider to mark offline.
 */
export async function markProviderOffline(providerId: string): Promise<void> {
  ctx.activeProviders.delete(providerId);

  if (ctx.dbReady) {
    await sql()`
      UPDATE unidesk_nodes
      SET status = 'offline', updated_at = now()
      WHERE provider_id = ${providerId}
    `;
    await recordEvent("provider_offline", providerId, { providerId });
  }
}
|
|
|
|
/**
 * Heartbeat sweep: finds providers that are still `online` but whose
 * `last_heartbeat` is older than `config().heartbeatTimeoutMs`, and flips them
 * to `offline`.
 *
 * For every row the UPDATE touched, the provider is evicted from
 * `ctx.activeProviders` and a `provider_heartbeat_timeout` event is recorded.
 * Rows whose `last_heartbeat` is NULL are not matched by the query.
 * Does nothing when the database is not ready.
 */
export async function markStaleProvidersOffline(): Promise<void> {
  if (!ctx.dbReady) {
    return;
  }

  const { heartbeatTimeoutMs } = config();
  // The threshold is bound as a plain number and scaled to an interval in SQL.
  const expired = await sql()<{ provider_id: string }[]>`
    UPDATE unidesk_nodes
    SET status = 'offline', updated_at = now()
    WHERE status = 'online'
      AND last_heartbeat IS NOT NULL
      AND last_heartbeat < now() - (${heartbeatTimeoutMs} * interval '1 millisecond')
    RETURNING provider_id
  `;

  // Evict, then record, one provider at a time; each event write is awaited
  // before moving on to the next provider.
  for (const { provider_id: staleProviderId } of expired) {
    ctx.activeProviders.delete(staleProviderId);
    await recordEvent("provider_heartbeat_timeout", staleProviderId, {
      providerId: staleProviderId,
      timeoutMs: heartbeatTimeoutMs,
    });
  }
}
|
|
|
|
/**
 * Checks whether a provider advertised a given capability.
 *
 * Capabilities are read from the `unideskCapabilities` array inside the node's
 * `labels` JSON, and only string entries are considered. The result is `false`
 * when any of these hold:
 * - the database is not ready;
 * - the provider has no row;
 * - `labels` is not a plain object;
 * - the capability list is absent or not an array.
 *
 * @param providerId - Provider to look up.
 * @param capability - Capability name to test for (exact match).
 */
export async function providerSupports(providerId: string, capability: string): Promise<boolean> {
  if (!ctx.dbReady) return false;

  const found = await sql()<Array<{ labels: unknown }>>`
    SELECT labels
    FROM unidesk_nodes
    WHERE provider_id = ${providerId}
    LIMIT 1
  `;

  const labels = found[0]?.labels;
  const isPlainObject = typeof labels === "object" && labels !== null && !Array.isArray(labels);
  if (!isPlainObject) return false;

  const advertised = (labels as Record<string, unknown>).unideskCapabilities;
  if (!Array.isArray(advertised)) return false;

  return advertised.some((entry) => typeof entry === "string" && entry === capability);
}
|
|
|
|
/**
 * Entry point for every frame received on a provider WebSocket.
 *
 * The frame is parsed and validated first; parse and validation errors
 * propagate to the caller. Any valid frame binds the socket to the sender's
 * `providerId` and (re)registers it in `ctx.activeProviders`. The frame is then
 * routed by `type`:
 *
 * - `host_ssh_*` → forwarded to the SSH bridge.
 * - `http_tunnel_response` → removes and invokes the pending tunnel waiter for
 *   `requestId`, or logs a warning if none is waiting.
 * - `egress_tcp_*` → handed to the egress TCP relay.
 * - `register` → upserts the node, with capabilities folded into the labels
 *   under `unideskCapabilities`. It then records an event and acks the
 *   provider.
 * - `heartbeat` → refreshes the node heartbeat.
 * - `system_status` / `docker_status` → stores the reported status snapshot.
 * - anything else → treated as a task status report. The task row is updated,
 *   a `task_status` event is recorded, and terminal statuses trigger
 *   `notifyTaskTerminal`.
 */
export async function handleProviderMessage(ws: ProviderSocket, raw: string | Buffer): Promise<void> {
  const message = parseMessage(raw);
  const { providerId } = message;

  ws.data.providerId = providerId;
  ctx.activeProviders.set(providerId, ws);

  switch (message.type) {
    case "host_ssh_opened":
    case "host_ssh_data":
    case "host_ssh_exit":
    case "host_ssh_error":
      forwardSshProviderMessage(message);
      return;

    case "http_tunnel_response": {
      const { requestId } = message;
      const resolveTunnel = ctx.httpTunnelWaiters.get(requestId);
      if (resolveTunnel === undefined) {
        logger("warn", "http_tunnel_response_without_waiter", { providerId, requestId });
        return;
      }
      // Remove the waiter before invoking it so it can only be settled once.
      ctx.httpTunnelWaiters.delete(requestId);
      resolveTunnel({ providerId, requestId, ok: message.ok, result: message.result });
      return;
    }

    case "egress_tcp_open":
      handleEgressTcpOpen(ws, message);
      return;
    case "egress_tcp_data":
      handleEgressTcpData(message);
      return;
    case "egress_tcp_close":
      handleEgressTcpClose(message);
      return;

    case "register": {
      const { name, capabilities } = message;
      const labels = { ...message.labels, unideskCapabilities: capabilities };
      await upsertProviderNode(providerId, name, labels);
      await recordEvent("provider_registered", providerId, {
        providerId,
        name,
        labels,
        capabilities,
      });
      ws.send(JSON.stringify({ type: "ack", requestId: "register", ok: true, message: "registered" }));
      return;
    }

    case "heartbeat":
      await updateProviderHeartbeat(providerId, message.labels);
      logger("debug", "provider_heartbeat", { providerId, labels: message.labels });
      return;

    case "system_status": {
      const snapshot = message.status;
      await upsertSystemStatus(providerId, snapshot as unknown as JsonValue, snapshot.collectedAt);
      logger("debug", "provider_system_status", {
        providerId,
        cpuPercent: snapshot.cpu.percent,
        memoryPercent: snapshot.memory.percent,
        diskPercent: snapshot.disk.percent,
        ok: snapshot.ok,
      });
      return;
    }

    case "docker_status": {
      const snapshot = message.status;
      await upsertDockerStatus(providerId, snapshot as unknown as JsonValue, snapshot.collectedAt);
      logger("debug", "provider_docker_status", { providerId, counts: snapshot.counts, ok: snapshot.ok });
      return;
    }

    default: {
      // Every remaining message variant is a task status report.
      const { taskId, status } = message;

      // Merge rules applied by the CASE expressions below:
      // - a non-terminal report never overwrites a task that is already
      //   `succeeded`/`failed`;
      // - a late `accepted` never demotes a `running` task;
      // - in every other case, including terminal -> terminal, the incoming
      //   status and result replace the stored ones.
      // If no result is supplied, `{ message }` is stored as the result.
      //
      // NOTE(review): the row is matched on task id only, not on the
      // reporting provider. Confirm whether an ownership column should also
      // be checked.
      // NOTE(review): unlike the other DB helpers in this file, there is no
      // `ctx.dbReady` guard here.
      await sql()`
        WITH incoming AS (
          SELECT ${status}::text AS status, ${sql().json(message.result ?? { message: message.message })}::jsonb AS result
        )
        UPDATE unidesk_tasks
        SET
          status = CASE
            WHEN unidesk_tasks.status IN ('succeeded', 'failed') AND incoming.status NOT IN ('succeeded', 'failed') THEN unidesk_tasks.status
            WHEN unidesk_tasks.status = 'running' AND incoming.status = 'accepted' THEN unidesk_tasks.status
            ELSE incoming.status
          END,
          result = CASE
            WHEN unidesk_tasks.status IN ('succeeded', 'failed') AND incoming.status NOT IN ('succeeded', 'failed') THEN unidesk_tasks.result
            WHEN unidesk_tasks.status = 'running' AND incoming.status = 'accepted' THEN unidesk_tasks.result
            ELSE incoming.result
          END,
          updated_at = now()
        FROM incoming
        WHERE id = ${taskId}
      `;

      await recordEvent("task_status", providerId, {
        providerId,
        taskId,
        status,
        message: message.message,
        result: message.result ?? null,
      });

      // NOTE(review): this fires even if the UPDATE matched no row or the task
      // was already terminal. The row count is not inspected.
      if (isTerminalTaskStatus(status)) {
        await notifyTaskTerminal(taskId);
      }
      return;
    }
  }
}
|