Files
pikasTech-unidesk/src/components/backend-core/src/provider-registry.ts
T
2026-05-16 16:03:53 +00:00

181 lines
6.6 KiB
TypeScript

import type {
JsonValue,
ProviderToCoreMessage,
} from "../../shared/src/index";
import { isProviderToCoreMessage as checkProviderMessage } from "../../shared/src/index";
import { ctx, sql, logger, config } from "./context";
import type { ProviderSocket, WsData } from "./types";
import { errorToJson, wsSendJson } from "./http";
import { recordEvent, upsertProviderNode, updateProviderHeartbeat, upsertDockerStatus, upsertSystemStatus } from "./db";
import { notifyTaskTerminal } from "./task-dispatcher";
import { forwardSshProviderMessage } from "./ssh-bridge";
import { handleEgressTcpOpen, handleEgressTcpData, handleEgressTcpClose } from "./egress-tcp";
/**
 * Reports whether a task status string denotes a final outcome.
 *
 * Only the exact strings "succeeded" and "failed" are treated as terminal;
 * every other value (including differently-cased variants) is not.
 *
 * @param status - Status string reported by a provider.
 * @returns `true` for "succeeded" or "failed", otherwise `false`.
 */
function isTerminalTaskStatus(status: string): boolean {
  switch (status) {
    case "succeeded":
    case "failed":
      return true;
    default:
      return false;
  }
}
/**
 * Decodes one raw frame received from a provider into a typed message.
 *
 * @param raw - Frame payload; a Buffer is decoded as UTF-8 text.
 * @returns The parsed value once `checkProviderMessage` accepts it.
 * @throws SyntaxError (from `JSON.parse`) when the payload is not valid JSON.
 * @throws Error when the JSON is rejected by `checkProviderMessage`; the error
 *   message embeds at most the first 200 characters of the payload.
 */
export function parseMessage(raw: string | Buffer): ProviderToCoreMessage {
  let payload: string;
  if (typeof raw === "string") {
    payload = raw;
  } else {
    payload = raw.toString("utf8");
  }

  const candidate: unknown = JSON.parse(payload);
  if (checkProviderMessage(candidate)) {
    return candidate;
  }

  // Truncate so a huge or hostile frame does not bloat the error/log line.
  const preview = payload.slice(0, 200);
  throw new Error(`Unsupported provider message: ${preview}`);
}
/**
 * Forgets a provider's live socket and, when the database is ready, flags its
 * node row as offline and records a "provider_offline" event.
 *
 * The in-memory removal from `ctx.activeProviders` always happens, even when
 * `ctx.dbReady` is false; only the persistence steps are skipped in that case.
 *
 * @param providerId - Identifier of the provider that went away.
 */
export async function markProviderOffline(providerId: string): Promise<void> {
  ctx.activeProviders.delete(providerId);

  if (ctx.dbReady) {
    await sql()`
UPDATE unidesk_nodes
SET status = 'offline', updated_at = now()
WHERE provider_id = ${providerId}
`;
    await recordEvent("provider_offline", providerId, { providerId });
  }
}
/**
 * Marks providers whose heartbeat has gone stale as offline.
 *
 * A node row is considered stale when its status is 'online', it has a
 * non-null `last_heartbeat`, and that heartbeat is older than
 * `config().heartbeatTimeoutMs` milliseconds. Rows that have never sent a
 * heartbeat (NULL `last_heartbeat`) are not touched by this sweep.
 *
 * For each row flipped offline, the provider is removed from
 * `ctx.activeProviders` and a "provider_heartbeat_timeout" event is recorded.
 * Does nothing when the database is not ready.
 */
export async function markStaleProvidersOffline(): Promise<void> {
  if (!ctx.dbReady) {
    return;
  }

  const { heartbeatTimeoutMs: timeoutMs } = config();
  const staleRows = await sql()<{ provider_id: string }[]>`
UPDATE unidesk_nodes
SET status = 'offline', updated_at = now()
WHERE status = 'online'
AND last_heartbeat IS NOT NULL
AND last_heartbeat < now() - (${timeoutMs} * interval '1 millisecond')
RETURNING provider_id
`;

  // Events are recorded one at a time, in the order the UPDATE returned rows.
  for (const { provider_id: staleId } of staleRows) {
    ctx.activeProviders.delete(staleId);
    await recordEvent("provider_heartbeat_timeout", staleId, { providerId: staleId, timeoutMs });
  }
}
/**
 * Checks whether a provider advertised a given capability when it registered.
 *
 * Capabilities are read from the `unideskCapabilities` entry of the node's
 * `labels` column. Only string entries of that array are considered.
 *
 * @param providerId - Provider whose node row is inspected.
 * @param capability - Capability name to look for (exact string match).
 * @returns `false` when the database is not ready, the node row is missing,
 *   `labels` is not a plain (non-array) object, `unideskCapabilities` is not an
 *   array, or the capability is absent; otherwise `true`.
 */
export async function providerSupports(providerId: string, capability: string): Promise<boolean> {
  if (!ctx.dbReady) {
    return false;
  }

  const rows = await sql()<Array<{ labels: unknown }>>`
SELECT labels
FROM unidesk_nodes
WHERE provider_id = ${providerId}
LIMIT 1
`;

  const labels = rows[0]?.labels;
  if (labels === null || typeof labels !== "object" || Array.isArray(labels)) {
    return false;
  }

  const declared = (labels as Record<string, unknown>).unideskCapabilities;
  if (!Array.isArray(declared)) {
    return false;
  }

  // Non-string entries are ignored; strings are compared for exact equality.
  return declared.some((entry) => typeof entry === "string" && entry === capability);
}
/**
 * Entry point for every frame a provider sends over its WebSocket.
 *
 * The frame is parsed with `parseMessage` (which throws on invalid input), the
 * socket is (re)bound to the provider id the message carries, and the message
 * is dispatched on its `type`:
 *
 * - `host_ssh_*`: forwarded to the SSH bridge.
 * - `http_tunnel_response`: resolves the pending waiter for `requestId`, or
 *   logs a warning if none is registered.
 * - `egress_tcp_open` / `_data` / `_close`: handed to the egress TCP handlers.
 * - `register`: upserts the node (capabilities stored under
 *   `labels.unideskCapabilities`), records an event and sends an ack frame.
 * - `heartbeat`, `system_status`, `docker_status`: persisted, then logged at
 *   debug level.
 * - anything else: treated as a task status update (recorded as a
 *   "task_status" event) and written to `unidesk_tasks`.
 *
 * @param ws - Provider socket the frame arrived on.
 * @param raw - Raw frame payload.
 */
export async function handleProviderMessage(ws: ProviderSocket, raw: string | Buffer): Promise<void> {
  const message = parseMessage(raw);

  // Every frame re-associates this socket with the provider id it carries.
  ws.data.providerId = message.providerId;
  ctx.activeProviders.set(message.providerId, ws);

  switch (message.type) {
    case "host_ssh_opened":
    case "host_ssh_data":
    case "host_ssh_exit":
    case "host_ssh_error":
      forwardSshProviderMessage(message);
      return;

    case "http_tunnel_response": {
      const { providerId, requestId, ok, result } = message;
      const waiter = ctx.httpTunnelWaiters.get(requestId);
      if (waiter === undefined) {
        logger("warn", "http_tunnel_response_without_waiter", { providerId, requestId });
        return;
      }
      // Remove before invoking so each waiter is resolved at most once.
      ctx.httpTunnelWaiters.delete(requestId);
      waiter({ providerId, requestId, ok, result });
      return;
    }

    case "egress_tcp_open":
      handleEgressTcpOpen(ws, message);
      return;

    case "egress_tcp_data":
      handleEgressTcpData(message);
      return;

    case "egress_tcp_close":
      handleEgressTcpClose(message);
      return;

    case "register": {
      const { providerId, name, capabilities } = message;
      const labels = { ...message.labels, unideskCapabilities: capabilities };
      await upsertProviderNode(providerId, name, labels);
      await recordEvent("provider_registered", providerId, {
        providerId,
        name,
        labels,
        capabilities,
      });
      const ack = { type: "ack", requestId: "register", ok: true, message: "registered" };
      ws.send(JSON.stringify(ack));
      return;
    }

    case "heartbeat": {
      const { providerId, labels } = message;
      await updateProviderHeartbeat(providerId, labels);
      logger("debug", "provider_heartbeat", { providerId, labels });
      return;
    }

    case "system_status": {
      const { providerId, status } = message;
      await upsertSystemStatus(providerId, status as unknown as JsonValue, status.collectedAt);
      logger("debug", "provider_system_status", {
        providerId,
        cpuPercent: status.cpu.percent,
        memoryPercent: status.memory.percent,
        diskPercent: status.disk.percent,
        ok: status.ok,
      });
      return;
    }

    case "docker_status": {
      const { providerId, status } = message;
      await upsertDockerStatus(providerId, status as unknown as JsonValue, status.collectedAt);
      logger("debug", "provider_docker_status", { providerId, counts: status.counts, ok: status.ok });
      return;
    }

    default: {
      // Remaining message kind: a task status update.
      // The UPDATE below refuses two regressions:
      //   * a terminal task ('succeeded'/'failed') is not moved back to a
      //     non-terminal status, and
      //   * a 'running' task ignores a late 'accepted'.
      // In those cases both status and result keep their stored values;
      // otherwise the incoming status/result overwrite them.
      // When the provider sends no result, { message } is stored instead.
      // NOTE(review): unlike the other DB paths in this module, this raw query
      // is not guarded by `ctx.dbReady` — confirm that is intended.
      const { providerId, taskId, status } = message;
      await sql()`
WITH incoming AS (
SELECT ${status}::text AS status, ${sql().json(message.result ?? { message: message.message })}::jsonb AS result
)
UPDATE unidesk_tasks
SET
status = CASE
WHEN unidesk_tasks.status IN ('succeeded', 'failed') AND incoming.status NOT IN ('succeeded', 'failed') THEN unidesk_tasks.status
WHEN unidesk_tasks.status = 'running' AND incoming.status = 'accepted' THEN unidesk_tasks.status
ELSE incoming.status
END,
result = CASE
WHEN unidesk_tasks.status IN ('succeeded', 'failed') AND incoming.status NOT IN ('succeeded', 'failed') THEN unidesk_tasks.result
WHEN unidesk_tasks.status = 'running' AND incoming.status = 'accepted' THEN unidesk_tasks.result
ELSE incoming.result
END,
updated_at = now()
FROM incoming
WHERE id = ${taskId}
`;
      // The event is recorded for every report, including ones the UPDATE ignored.
      await recordEvent("task_status", providerId, {
        providerId,
        taskId,
        status,
        message: message.message,
        result: message.result ?? null,
      });
      // Decided from the incoming status, not from what was actually stored.
      if (isTerminalTaskStatus(status)) {
        await notifyTaskTerminal(taskId);
      }
    }
  }
}