Graph Neural Network

The GNN provides Layer 2 (structural) scoring by analyzing the relational context around each action. It operates on a heterogeneous graph with 13 node types and 20 edge types, using multi-task learning for risk classification, severity regression, and threat category detection.

Architecture: SubgraphGNN

Event → Subgraph Extraction (k=3 hop) → Node Feature Encoding
    → 2-layer HeteroConv (SAGEConv per edge type, sum aggregation)
    → Global Mean Pooling (over Action nodes)
    → Multi-Task Heads
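The first stage of the pipeline, k-hop subgraph extraction, is a breadth-first traversal bounded at k=3. A minimal sketch in plain Python — the adjacency format, node ids, and `k_hop_subgraph` helper name are illustrative, not the actual implementation:

```python
from collections import deque

def k_hop_subgraph(adj, seed, k=3):
    """Collect all nodes within k hops of `seed` via BFS.

    `adj` maps a node id to its neighbors (an undirected view of the
    typed graph); node and edge types are tracked separately.
    """
    seen = {seed}
    frontier = deque([(seed, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == k:
            continue  # do not expand past the hop limit
        for nbr in adj.get(node, ()):
            if nbr not in seen:
                seen.add(nbr)
                frontier.append((nbr, depth + 1))
    return seen

# Toy chain a - b - c - d - e: from "a" with k=3 we reach a..d but not e.
adj = {"a": ["b"], "b": ["a", "c"], "c": ["b", "d"], "d": ["c", "e"], "e": ["d"]}
print(sorted(k_hop_subgraph(adj, "a", k=3)))  # ['a', 'b', 'c', 'd']
```

The extracted node set is then encoded (per the feature tables below) and handed to the HeteroConv layers.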

Multi-Task Outputs

| Head | Output | Loss |
|---|---|---|
| Risk level classification | 5 classes (none/low/medium/high/critical) | Cross-entropy |
| Severity score regression | 0-100 continuous | MSE |
| Threat category multi-label | 10 binary labels | Binary cross-entropy |
| Binary threat detection | is_threat (0/1) | Binary cross-entropy |
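Multi-task training combines these four losses into one objective. A minimal sketch on toy numbers in plain Python — the loss weights, dictionary layout, and key names are hypothetical, not taken from the codebase:

```python
import math

def cross_entropy(probs, target):  # risk-level head (5-way softmax)
    return -math.log(probs[target])

def mse(pred, target):             # severity head (0-100 regression)
    return (pred - target) ** 2

def bce(p, y):                     # per-label binary heads
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def multitask_loss(out, labels, w=(1.0, 0.01, 1.0, 1.0)):
    """Weighted sum of the four head losses (weights are hypothetical)."""
    l_risk = cross_entropy(out["risk_probs"], labels["risk"])
    l_sev = mse(out["severity"], labels["severity"])
    l_cat = sum(bce(p, y) for p, y in
                zip(out["cat_probs"], labels["cats"])) / len(out["cat_probs"])
    l_bin = bce(out["threat_prob"], labels["is_threat"])
    return w[0] * l_risk + w[1] * l_sev + w[2] * l_cat + w[3] * l_bin

out = {
    "risk_probs": [0.05, 0.10, 0.20, 0.50, 0.15],  # none..critical
    "severity": 68.0,
    "cat_probs": [0.9] + [0.1] * 9,                # 10 threat categories
    "threat_prob": 0.8,
}
labels = {"risk": 3, "severity": 72.0, "cats": [1] + [0] * 9, "is_threat": 1}
loss = multitask_loss(out, labels)
```

Down-weighting the MSE term compensates for severity being on a 0-100 scale while the other losses are in nats.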

Threat Categories (10)

| Category | Description |
|---|---|
| policy_violation | Customer policy breach |
| exfiltration | Data leaving security boundary |
| privilege_escalation | Increasing access beyond scope |
| scope_creep | Accessing new resource types |
| excessive_agency | Agent acting beyond instructions |
| multi_step_attack | Coordinated action chain |
| prompt_injection | Prompt manipulation attempt |
| tool_poisoning | MCP tool description manipulation |
| supply_chain | Dependency/server compromise |
| anomalous_access | Statistical behavioral anomaly |
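Because the category head is trained with per-label binary cross-entropy, its ten sigmoid outputs decode independently by thresholding. A sketch — the 0.5 cutoff and the probability values below are illustrative assumptions:

```python
THREAT_CATEGORIES = [
    "policy_violation", "exfiltration", "privilege_escalation",
    "scope_creep", "excessive_agency", "multi_step_attack",
    "prompt_injection", "tool_poisoning", "supply_chain",
    "anomalous_access",
]

def active_categories(probs, threshold=0.5):
    """Map the 10 per-label sigmoid outputs to category names above the cutoff."""
    return [c for c, p in zip(THREAT_CATEGORIES, probs) if p >= threshold]

probs = [0.92, 0.81, 0.05, 0.30, 0.12, 0.07, 0.03, 0.02, 0.01, 0.44]
print(active_categories(probs))  # ['policy_violation', 'exfiltration']
```

Multi-label decoding means one event can carry several categories at once (e.g. an exfiltration that is also a policy violation).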

Node Types & Features

| Node Type | Feature Dim | Key Features |
|---|---|---|
| Action | 52 | action_type(20) + verb_risk(1) + tool_hash(8) + tool_risk(1) + param_features(20) + time(2) |
| Agent | 24 | agent_type(14) + model(4) + deployment_context(6) |
| Resource | 22 | sensitivity(1) + resource_type(13) + data_classification(5) + name_risk(3) |
| User | 17 | clearance(1) + department(10) + role(6) |
| Policy | 10 | severity(1) + rule_type(6) + scope_hash(3) |
| Session | 4 | message_count(1) + tool_call_count(1) + time_features(2) |

The remaining seven node types (DataField, ExternalEndpoint, MCPServer, MCPTool, MCPResource, MCPPrompt, OAuthToken) appear as endpoints in the edge types below.
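The Action row implies a simple concatenation layout summing to 52 dimensions. A sketch — the one-hot/hashed sub-encodings and parameter names are assumptions for illustration, not the actual encoders:

```python
def one_hot(index, size):
    vec = [0.0] * size
    vec[index] = 1.0
    return vec

def action_features(action_type, verb_risk, tool_hash, tool_risk,
                    param_features, time_features):
    """Concatenate the Action sub-vectors into one 52-dim feature vector."""
    vec = (one_hot(action_type, 20)   # action_type(20)
           + [verb_risk]              # verb_risk(1)
           + tool_hash                # tool_hash(8)
           + [tool_risk]              # tool_risk(1)
           + param_features           # param_features(20)
           + time_features)           # time(2)
    assert len(vec) == 52, len(vec)
    return vec

feat = action_features(
    action_type=4,              # index into the 20 action types
    verb_risk=0.7,
    tool_hash=[0.0] * 8,        # 8-dim hashed tool identity
    tool_risk=0.5,
    param_features=[0.0] * 20,
    time_features=[0.25, 0.5],  # e.g. encoded hour / day-of-week
)
print(len(feat))  # 52
```

The same concatenation pattern applies to the other node types, just with their own sub-vector sizes from the table.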

Edge Types (20)

("Agent", "PERFORMED", "Action")          # Core behavioral
("Action", "ACCESSED", "Resource")
("Action", "TOUCHED_FIELD", "DataField")
("Action", "SENT_TO", "ExternalEndpoint")
("Session", "INITIATED_BY", "User")
("Agent", "BELONGS_TO", "Session")
("Resource", "GOVERNED_BY", "Policy")
("Policy", "PERMITS", "Action")           # Policy evaluation
("Policy", "BLOCKS", "Action")
("Action", "PRECEDED_BY", "Action")       # Temporal chains
("Action", "ESCALATED_FROM", "Action")
("Action", "SIMILAR_TO", "Action")
("Agent", "CONNECTED_TO", "MCPServer")    # MCP topology
("Action", "INVOKED", "MCPTool")
("Action", "READ_RESOURCE", "MCPResource")
("Action", "USED_PROMPT", "MCPPrompt")
("MCPServer", "EXPOSES", "MCPTool")
("MCPServer", "AUTHENTICATED_BY", "OAuthToken")
("MCPTool", "DESCRIPTION_CHANGED", "MCPTool")
("MCPTool", "CROSS_SERVER_CALL", "MCPTool")
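One plausible in-memory layout for this typed edge set is a dict keyed by (source, relation, target) triples, each holding parallel source/target id lists. The ids and the `in_degree` helper below are illustrative only:

```python
# Typed edges as (source ids, target ids) per relation triple.
edge_index = {
    ("Agent", "PERFORMED", "Action"): ([0], [0]),
    ("Action", "ACCESSED", "Resource"): ([0], [0]),
    ("Action", "PRECEDED_BY", "Action"): ([1], [0]),
}

def in_degree(edge_index, node_type, num_nodes):
    """Count incoming typed edges per node of `node_type`."""
    deg = [0] * num_nodes
    for (_src, _rel, dst), (_srcs, dsts) in edge_index.items():
        if dst == node_type:
            for d in dsts:
                deg[d] += 1
    return deg

print(in_degree(edge_index, "Action", 2))  # [2, 0]
```

This keyed layout mirrors the per-edge-type convolutions in the architecture: each relation triple gets its own SAGEConv, and per-node messages are summed across relations.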

Score Production

The GNN maps its probability distribution to a 0-100 score using risk class centers:
centers = {"none": 5, "low": 20, "medium": 45, "high": 72, "critical": 92}
gnn_score = sum(prob[c] * centers[c] for c in risk_classes)
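Run end-to-end on a hypothetical probability distribution, the mapping looks like this (the probabilities are made up for illustration):

```python
centers = {"none": 5, "low": 20, "medium": 45, "high": 72, "critical": 92}

def gnn_score(prob):
    """Expected risk score: probability-weighted average of class centers."""
    return sum(prob[c] * centers[c] for c in centers)

# Hypothetical softmax output, dominated by "high".
prob = {"none": 0.05, "low": 0.10, "medium": 0.20, "high": 0.50, "critical": 0.15}
print(gnn_score(prob))  # ≈ 61.05
```

Because the score is an expectation over class centers rather than an argmax, a confident "high" prediction and an uncertain "critical" prediction can land near the same point on the 0-100 scale.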

GNN Predictor API

from quint_graph.gnn import GNNPredictor

predictor = GNNPredictor(
    model_path="path/to/model.pt",
    hidden_channels=64,
    num_layers=2,
    device="cpu"
)

# Full inference
result = predictor.predict(event, centrality=None)
# Returns: risk_class, probabilities, confidence, severity, is_threat

# Score only
score = predictor.predict_score(event, centrality=None)
# Returns: float 0-100

Training Data

The GNN trains on synthetic data generated without proprietary customer data:

| Source | Examples | Percentage |
|---|---|---|
| Simulated (benign) | 12,000-18,000 | ~40% |
| Simulated (threat) | 8,000-12,000 | ~25% |
| Adversarial evasion | 5,000-8,000 | ~18% |
| LLM-generated scenarios | 3,000-5,000 | ~12% |
| Real-world (design partners) | 500-2,000 | ~5% |

Evaluation Targets

| Metric | Target |
|---|---|
| Risk Score MAE | < 8 points |
| Risk Level Accuracy | > 92% |
| Threat Detection F1 | > 95% |
| False Positive Rate | < 3% |
| Adversarial Detection Rate | > 75% |

GPU is NOT required for inference. For our graph size (~1,500 ontology nodes plus dynamic event nodes), CPU inference takes 5-15 ms.