From 926e9479a42f6514f89fccffd17648fa02d17b11 Mon Sep 17 00:00:00 2001 From: Bartosz Majsak Date: Sun, 20 Jul 2025 02:43:06 +0200 Subject: [PATCH 1/9] fix(test): uses local httptest server urls for downloader tests (#4601) Signed-off-by: Bartosz Majsak --- pkg/agent/watcher_test.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/pkg/agent/watcher_test.go b/pkg/agent/watcher_test.go index 597d900da6a..980648b5405 100644 --- a/pkg/agent/watcher_test.go +++ b/pkg/agent/watcher_test.go @@ -703,6 +703,12 @@ var _ = Describe("Watcher", func() { }, } for protocol, scenario := range scenarios { + ts := scenario.server + defer ts.Close() + cl := storage.HTTPSProvider{ + Client: ts.Client(), + } + logger.Printf("Setting up %s Server", protocol) logger.Printf("Sync model config using temp dir %v\n", modelDir) watcher := NewWatcher("/tmp/configs", modelDir, sugar) @@ -710,26 +716,19 @@ var _ = Describe("Watcher", func() { { Name: "model1", Spec: v1alpha1.ModelSpec{ - StorageURI: "http://example.com/test.tar", + StorageURI: ts.URL + "/test.tar", Framework: "sklearn", }, }, { Name: "model2", Spec: v1alpha1.ModelSpec{ - StorageURI: "https://example.com/test.zip", + StorageURI: ts.URL + "/test.zip", Framework: "sklearn", }, }, } - // Create HTTPS client - ts := scenario.server - defer ts.Close() - cl := storage.HTTPSProvider{ - Client: ts.Client(), - } - watcher.parseConfig(modelConfigs, false) puller := Puller{ channelMap: make(map[string]*ModelChannel), From 8b29459558a29e4bd274e518280b0f113e504c7e Mon Sep 17 00:00:00 2001 From: Vincent Date: Sun, 20 Jul 2025 09:45:19 +0900 Subject: [PATCH 2/9] Fixed the issue of the same metrics across different deployments under different namespaces (#4593) Signed-off-by: Vincent Hou --- .../rawkube_controller_test.go | 17 ++- .../reconcilers/keda/keda_reconciler.go | 10 +- .../reconcilers/keda/keda_reconciler_test.go | 89 +++++++++-- .../reconcilers/otel/otel_reconciler.go | 140 +++++++++++++----- 
.../reconcilers/otel/otel_reconciler_test.go | 91 +++++++----- test/e2e/predictor/test_autoscaling.py | 2 +- 6 files changed, 254 insertions(+), 95 deletions(-) diff --git a/pkg/controller/v1beta1/inferenceservice/rawkube_controller_test.go b/pkg/controller/v1beta1/inferenceservice/rawkube_controller_test.go index 303122241a0..b00450284d2 100644 --- a/pkg/controller/v1beta1/inferenceservice/rawkube_controller_test.go +++ b/pkg/controller/v1beta1/inferenceservice/rawkube_controller_test.go @@ -10327,6 +10327,21 @@ var _ = Describe("v1beta1 inference service controller", func() { }, }, }, + "resourcedetection/env": map[string]interface{}{ + "detectors": []interface{}{string("env")}, + "override": bool(false), + "timeout": string("2s"), + }, + "transform": map[string]interface{}{ + "metric_statements": []interface{}{map[string]interface{}{ + "context": string("datapoint"), + "statements": []interface{}{ + string("set(attributes[\"namespace\"], resource.attributes[\"k8s.namespace.name\"])"), + string("set(attributes[\"deployment\"], resource.attributes[\"k8s.deployment.name\"])"), + string("set(attributes[\"pod\"], resource.attributes[\"k8s.pod.name\"])"), + }, + }}, + }, }}, Exporters: otelv1beta1.AnyConfig{Object: map[string]interface{}{ "otlp": map[string]interface{}{ @@ -10341,7 +10356,7 @@ var _ = Describe("v1beta1 inference service controller", func() { Pipelines: map[string]*otelv1beta1.Pipeline{ "metrics": { Receivers: []string{"prometheus"}, - Processors: []string{"filter/metrics"}, + Processors: []string{"resourcedetection/env", "transform", "filter/metrics"}, Exporters: []string{"otlp"}, }, }, diff --git a/pkg/controller/v1beta1/inferenceservice/reconcilers/keda/keda_reconciler.go b/pkg/controller/v1beta1/inferenceservice/reconcilers/keda/keda_reconciler.go index 16e7269db09..75a60039ac0 100644 --- a/pkg/controller/v1beta1/inferenceservice/reconcilers/keda/keda_reconciler.go +++ b/pkg/controller/v1beta1/inferenceservice/reconcilers/keda/keda_reconciler.go @@ 
-67,7 +67,7 @@ func NewKedaReconciler(client client.Client, }, nil } -func getKedaMetrics(componentExt *v1beta1.ComponentExtensionSpec, configMap *corev1.ConfigMap, +func getKedaMetrics(componentMeta metav1.ObjectMeta, componentExt *v1beta1.ComponentExtensionSpec, configMap *corev1.ConfigMap, ) ([]kedav1alpha1.ScaleTriggers, error) { var triggers []kedav1alpha1.ScaleTriggers @@ -156,8 +156,12 @@ func getKedaMetrics(componentExt *v1beta1.ComponentExtensionSpec, configMap *cor if triggerType == string(constants.AutoScalerMetricsSourceOpenTelemetry) { trigger.Type = "external" + // Inject namespace and deployment label selectors into the query for metric isolation. + // This ensures the metricQuery only selects metrics for the correct deployment and namespace. + // Example: sum({namespace="", deployment=""}) + metricQuery := fmt.Sprintf("sum(%s{namespace=\"%s\", deployment=\"%s\"})", query, componentMeta.Namespace, componentMeta.Name) trigger.Metadata = map[string]string{ - "metricQuery": query, + "metricQuery": metricQuery, "targetValue": fmt.Sprintf("%f", targetValue), "scalerAddress": MetricScalerEndpoint, } @@ -189,7 +193,7 @@ func createKedaScaledObject(componentMeta metav1.ObjectMeta, if MaxReplicas < *MinReplicas { MaxReplicas = *MinReplicas } - triggers, err := getKedaMetrics(componentExtension, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExtension, configMap) if err != nil { return nil, err } diff --git a/pkg/controller/v1beta1/inferenceservice/reconcilers/keda/keda_reconciler_test.go b/pkg/controller/v1beta1/inferenceservice/reconcilers/keda/keda_reconciler_test.go index eb9e8ee4dd6..76830e80ca6 100644 --- a/pkg/controller/v1beta1/inferenceservice/reconcilers/keda/keda_reconciler_test.go +++ b/pkg/controller/v1beta1/inferenceservice/reconcilers/keda/keda_reconciler_test.go @@ -55,10 +55,14 @@ func TestNewKedaReconciler(t *testing.T) { } func TestGetKedaMetrics_ResourceMetricSourceType(t *testing.T) { + componentMeta := 
metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := createComponentExtensionWithResourceMetric() configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) assert.Equal(t, "cpu", triggers[0].Type) @@ -66,10 +70,14 @@ func TestGetKedaMetrics_ResourceMetricSourceType(t *testing.T) { } func TestGetKedaMetrics_ExternalMetricSourceType(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := createComponentExtensionWithExternalMetric() configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) assert.Equal(t, "prometheus", triggers[0].Type) @@ -79,15 +87,20 @@ func TestGetKedaMetrics_ExternalMetricSourceType(t *testing.T) { } func TestGetKedaMetrics_PodMetricSourceType(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := createComponentExtensionWithPodMetric() configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) assert.Equal(t, "external", triggers[0].Type) assert.Equal(t, "http://otel-server", triggers[0].Metadata["scalerAddress"]) - assert.Equal(t, "otel_query", triggers[0].Metadata["metricQuery"]) + // The metricQuery should now include namespace and deployment selectors + assert.Equal(t, "sum(otel_query{namespace=\"test-namespace\", deployment=\"test-component\"})", triggers[0].Metadata["metricQuery"]) assert.Equal(t, "200.000000", triggers[0].Metadata["targetValue"]) } @@ -242,6 +255,10 @@ func 
TestReconcile_UpdateScaledObject(t *testing.T) { } func TestGetKedaMetrics_AverageValueMetricSourceType(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: &v1beta1.AutoScalingSpec{ Metrics: []v1beta1.MetricsSpec{ @@ -260,7 +277,7 @@ func TestGetKedaMetrics_AverageValueMetricSourceType(t *testing.T) { } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) assert.Equal(t, "cpu", triggers[0].Type) @@ -268,6 +285,10 @@ func TestGetKedaMetrics_AverageValueMetricSourceType(t *testing.T) { } func TestGetKedaMetrics_ValueMetricSourceType(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: &v1beta1.AutoScalingSpec{ Metrics: []v1beta1.MetricsSpec{ @@ -286,7 +307,7 @@ func TestGetKedaMetrics_ValueMetricSourceType(t *testing.T) { } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) assert.Equal(t, "memory", triggers[0].Type) @@ -294,6 +315,10 @@ func TestGetKedaMetrics_ValueMetricSourceType(t *testing.T) { } func TestGetKedaMetrics_DefaultCPUUtilization(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: &v1beta1.AutoScalingSpec{ Metrics: []v1beta1.MetricsSpec{ @@ -311,7 +336,7 @@ func TestGetKedaMetrics_DefaultCPUUtilization(t *testing.T) { } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, 
componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) assert.Equal(t, "cpu", triggers[0].Type) @@ -381,17 +406,25 @@ func TestCreateKedaScaledObject_MaxReplicasLessThanMinReplicas(t *testing.T) { } func TestGetKedaMetrics_NilAutoScaling(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: nil, } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Empty(t, triggers) } func TestGetKedaMetrics_ResourceMetricSourceType_Utilization(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: &v1beta1.AutoScalingSpec{ Metrics: []v1beta1.MetricsSpec{ @@ -409,7 +442,7 @@ func TestGetKedaMetrics_ResourceMetricSourceType_Utilization(t *testing.T) { }, } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) assert.Equal(t, "cpu", triggers[0].Type) @@ -417,6 +450,10 @@ func TestGetKedaMetrics_ResourceMetricSourceType_Utilization(t *testing.T) { } func TestGetKedaMetrics_ResourceMetricSourceType_Utilization_DefaultCPU(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: &v1beta1.AutoScalingSpec{ Metrics: []v1beta1.MetricsSpec{ @@ -433,7 +470,7 @@ func TestGetKedaMetrics_ResourceMetricSourceType_Utilization_DefaultCPU(t *testi }, } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) 
require.NoError(t, err) assert.Len(t, triggers, 1) assert.Equal(t, "cpu", triggers[0].Type) @@ -441,6 +478,10 @@ func TestGetKedaMetrics_ResourceMetricSourceType_Utilization_DefaultCPU(t *testi } func TestGetKedaMetrics_ResourceMetricSourceType_AverageValue(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: &v1beta1.AutoScalingSpec{ Metrics: []v1beta1.MetricsSpec{ @@ -458,7 +499,7 @@ func TestGetKedaMetrics_ResourceMetricSourceType_AverageValue(t *testing.T) { }, } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) assert.Equal(t, "memory", triggers[0].Type) @@ -466,6 +507,10 @@ func TestGetKedaMetrics_ResourceMetricSourceType_AverageValue(t *testing.T) { } func TestGetKedaMetrics_ResourceMetricSourceType_Value(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: &v1beta1.AutoScalingSpec{ Metrics: []v1beta1.MetricsSpec{ @@ -483,7 +528,7 @@ func TestGetKedaMetrics_ResourceMetricSourceType_Value(t *testing.T) { }, } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) assert.Equal(t, "memory", triggers[0].Type) @@ -491,6 +536,10 @@ func TestGetKedaMetrics_ResourceMetricSourceType_Value(t *testing.T) { } func TestGetKedaMetrics_ExternalMetricSourceType_WithNamespaceAndAuth(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: &v1beta1.AutoScalingSpec{ Metrics: 
[]v1beta1.MetricsSpec{ @@ -518,7 +567,7 @@ func TestGetKedaMetrics_ExternalMetricSourceType_WithNamespaceAndAuth(t *testing }, } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) trigger := triggers[0] @@ -533,6 +582,10 @@ func TestGetKedaMetrics_ExternalMetricSourceType_WithNamespaceAndAuth(t *testing } func TestGetKedaMetrics_ExternalMetricSourceType_WithoutNamespaceOrAuth(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: &v1beta1.AutoScalingSpec{ Metrics: []v1beta1.MetricsSpec{ @@ -553,7 +606,7 @@ func TestGetKedaMetrics_ExternalMetricSourceType_WithoutNamespaceOrAuth(t *testi }, } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) trigger := triggers[0] @@ -565,6 +618,10 @@ func TestGetKedaMetrics_ExternalMetricSourceType_WithoutNamespaceOrAuth(t *testi } func TestGetKedaMetrics_PodMetricSourceType_Success(t *testing.T) { + componentMeta := metav1.ObjectMeta{ + Name: "test-component", + Namespace: "test-namespace", + } componentExt := &v1beta1.ComponentExtensionSpec{ AutoScaling: &v1beta1.AutoScalingSpec{ Metrics: []v1beta1.MetricsSpec{ @@ -586,12 +643,12 @@ func TestGetKedaMetrics_PodMetricSourceType_Success(t *testing.T) { }, } configMap := &corev1.ConfigMap{} - triggers, err := getKedaMetrics(componentExt, configMap) + triggers, err := getKedaMetrics(componentMeta, componentExt, configMap) require.NoError(t, err) assert.Len(t, triggers, 1) trigger := triggers[0] assert.Equal(t, "external", trigger.Type) - assert.Equal(t, "otel_query", trigger.Metadata["metricQuery"]) + assert.Equal(t, 
"sum(otel_query{namespace=\"test-namespace\", deployment=\"test-component\"})", trigger.Metadata["metricQuery"]) assert.Equal(t, "200.000000", trigger.Metadata["targetValue"]) assert.Equal(t, "http://otel-server", trigger.Metadata["scalerAddress"]) assert.Equal(t, "sum", trigger.Metadata["operationOverTime"]) diff --git a/pkg/controller/v1beta1/inferenceservice/reconcilers/otel/otel_reconciler.go b/pkg/controller/v1beta1/inferenceservice/reconcilers/otel/otel_reconciler.go index 6503e170c6a..7e759799b62 100644 --- a/pkg/controller/v1beta1/inferenceservice/reconcilers/otel/otel_reconciler.go +++ b/pkg/controller/v1beta1/inferenceservice/reconcilers/otel/otel_reconciler.go @@ -35,7 +35,54 @@ import ( logf "sigs.k8s.io/controller-runtime/pkg/log" ) -const ModeSidecar = "sidecar" +const ( + ProcessorResourcedetectionEnv = "resourcedetection/env" + ProcessorTransform = "transform" + ProcessorFilterMetrics = "filter/metrics" + JobNameOtelCollector = "otel-collector" + PrometheusReceiver = "prometheus" + OtlpExporter = "otlp" + ModeSidecar = "sidecar" + + AnnotationPrometheusPort = "prometheus.kserve.io/port" + DefaultPrometheusPort = "8080" + + ResourcedetectionDetectorEnv = "env" + ResourcedetectionTimeout = "2s" + ResourcedetectionOverride = false + TransformContextDatapoint = "datapoint" + StatementSetNamespace = "set(attributes[\"namespace\"], resource.attributes[\"k8s.namespace.name\"])" + StatementSetDeployment = "set(attributes[\"deployment\"], resource.attributes[\"k8s.deployment.name\"])" + StatementSetPod = "set(attributes[\"pod\"], resource.attributes[\"k8s.pod.name\"])" + + MatchTypeStrict = "strict" + PipelineMetrics = "metrics" + CompressionNone = "none" + TlsKey = "tls" + TlsInsecureKey = "insecure" + EndpointKey = "endpoint" + + KeyDetectors = "detectors" + KeyTimeout = "timeout" + KeyOverride = "override" + KeyMetricStatements = "metric_statements" + KeyContext = "context" + KeyStatements = "statements" + KeyMetrics = "metrics" + KeyInclude = "include" + 
KeyMatchType = "match_type" + KeyMetricNames = "metric_names" + KeyConfig = "config" + KeyScrapeConfigs = "scrape_configs" + KeyJobName = "job_name" + KeyScrapeInterval = "scrape_interval" + KeyStaticConfigs = "static_configs" + KeyTargets = "targets" + KeyCompression = "compression" + KeyTls = "tls" + KeyInsecure = "insecure" + KeyEndpoint = "endpoint" +) var log = logf.Log.WithName("OTelReconciler") @@ -62,10 +109,45 @@ func createOtelCollector(componentMeta metav1.ObjectMeta, metricNames []string, otelConfig v1beta1.OtelCollectorConfig, ) *otelv1beta1.OpenTelemetryCollector { - port, ok := componentMeta.Annotations["prometheus.kserve.io/port"] + port, ok := componentMeta.Annotations[AnnotationPrometheusPort] if !ok { - log.Info("Annotation prometheus.kserve.io/port is missing, using default value 8080 to configure OTel Collector") - port = "8080" + log.Info(fmt.Sprintf("Annotation %s is missing, using default value %s to configure OTel Collector", AnnotationPrometheusPort, DefaultPrometheusPort)) + port = DefaultPrometheusPort + } + + processors := map[string]interface{}{ + ProcessorResourcedetectionEnv: map[string]interface{}{ + KeyDetectors: []interface{}{ResourcedetectionDetectorEnv}, + KeyTimeout: ResourcedetectionTimeout, + KeyOverride: ResourcedetectionOverride, + }, + ProcessorTransform: map[string]interface{}{ + KeyMetricStatements: []interface{}{ + map[string]interface{}{ + KeyContext: TransformContextDatapoint, + KeyStatements: []interface{}{ + StatementSetNamespace, + StatementSetDeployment, + StatementSetPod, + }, + }, + }, + }, + } + + pipelineProcessors := []string{ProcessorResourcedetectionEnv, ProcessorTransform} + + // Add filter processor to include all specified metrics + if len(metricNames) > 0 { + processors[ProcessorFilterMetrics] = map[string]interface{}{ + KeyMetrics: map[string]interface{}{ + KeyInclude: map[string]interface{}{ + KeyMatchType: MatchTypeStrict, + KeyMetricNames: metricNames, + }, + }, + } + pipelineProcessors = 
append(pipelineProcessors, ProcessorFilterMetrics) } otelCollector := &otelv1beta1.OpenTelemetryCollector{ @@ -75,18 +157,18 @@ func createOtelCollector(componentMeta metav1.ObjectMeta, Annotations: componentMeta.Annotations, }, Spec: otelv1beta1.OpenTelemetryCollectorSpec{ - Mode: otelv1beta1.ModeSidecar, + Mode: ModeSidecar, Config: otelv1beta1.Config{ Receivers: otelv1beta1.AnyConfig{Object: map[string]interface{}{ - "prometheus": map[string]interface{}{ - "config": map[string]interface{}{ - "scrape_configs": []interface{}{ + PrometheusReceiver: map[string]interface{}{ + KeyConfig: map[string]interface{}{ + KeyScrapeConfigs: []interface{}{ map[string]interface{}{ - "job_name": "otel-collector", - "scrape_interval": otelConfig.ScrapeInterval, - "static_configs": []interface{}{ + KeyJobName: JobNameOtelCollector, + KeyScrapeInterval: otelConfig.ScrapeInterval, + KeyStaticConfigs: []interface{}{ map[string]interface{}{ - "targets": []interface{}{"localhost:" + port}, + KeyTargets: []interface{}{"localhost:" + port}, }, }, }, @@ -95,20 +177,21 @@ func createOtelCollector(componentMeta metav1.ObjectMeta, }, }}, Exporters: otelv1beta1.AnyConfig{Object: map[string]interface{}{ - "otlp": map[string]interface{}{ - "endpoint": otelConfig.MetricReceiverEndpoint, - "compression": "none", - "tls": map[string]interface{}{ - "insecure": true, + OtlpExporter: map[string]interface{}{ + KeyEndpoint: otelConfig.MetricReceiverEndpoint, + KeyCompression: CompressionNone, + KeyTls: map[string]interface{}{ + KeyInsecure: true, }, }, }}, + Processors: &otelv1beta1.AnyConfig{Object: processors}, Service: otelv1beta1.Service{ Pipelines: map[string]*otelv1beta1.Pipeline{ - "metrics": { - Receivers: []string{"prometheus"}, - Processors: []string{}, - Exporters: []string{"otlp"}, + PipelineMetrics: { + Receivers: []string{PrometheusReceiver}, + Processors: pipelineProcessors, + Exporters: []string{OtlpExporter}, }, }, }, @@ -116,21 +199,6 @@ func createOtelCollector(componentMeta 
metav1.ObjectMeta, }, } - // Add filter processor to include all specified metrics - if len(metricNames) > 0 { - otelCollector.Spec.Config.Processors = &otelv1beta1.AnyConfig{Object: map[string]interface{}{ - "filter/metrics": map[string]interface{}{ - "metrics": map[string]interface{}{ - "include": map[string]interface{}{ - "match_type": "strict", - "metric_names": metricNames, - }, - }, - }, - }} - otelCollector.Spec.Config.Service.Pipelines["metrics"].Processors = []string{"filter/metrics"} - } - return otelCollector } diff --git a/pkg/controller/v1beta1/inferenceservice/reconcilers/otel/otel_reconciler_test.go b/pkg/controller/v1beta1/inferenceservice/reconcilers/otel/otel_reconciler_test.go index 2ca8dea2283..d9c3c170a73 100644 --- a/pkg/controller/v1beta1/inferenceservice/reconcilers/otel/otel_reconciler_test.go +++ b/pkg/controller/v1beta1/inferenceservice/reconcilers/otel/otel_reconciler_test.go @@ -47,7 +47,7 @@ func TestCreateOtelCollector(t *testing.T) { Name: "test-service", Namespace: "default", Annotations: map[string]string{ - "prometheus.kserve.io/port": "9090", + AnnotationPrometheusPort: "9090", }, }, metricNames: []string{"request-count"}, @@ -56,11 +56,11 @@ func TestCreateOtelCollector(t *testing.T) { MetricReceiverEndpoint: "otel-collector:4317", }, expectedConfig: map[string]interface{}{ - "job_name": "otel-collector", - "scrape_interval": "15s", - "static_configs": []interface{}{ + KeyJobName: JobNameOtelCollector, + KeyScrapeInterval: "15s", + KeyStaticConfigs: []interface{}{ map[string]interface{}{ - "targets": []interface{}{"localhost:9090"}, + KeyTargets: []interface{}{"localhost:9090"}, }, }, }, @@ -76,11 +76,11 @@ func TestCreateOtelCollector(t *testing.T) { ScrapeInterval: "30s", }, expectedConfig: map[string]interface{}{ - "job_name": "otel-collector", - "scrape_interval": "30s", - "static_configs": []interface{}{ + KeyJobName: JobNameOtelCollector, + KeyScrapeInterval: "30s", + KeyStaticConfigs: []interface{}{ 
map[string]interface{}{ - "targets": []interface{}{"localhost:8080"}, + KeyTargets: []interface{}{"localhost:8080"}, }, }, }, @@ -97,32 +97,47 @@ func TestCreateOtelCollector(t *testing.T) { // Assert config details receivers := collector.Spec.Config.Receivers.Object - prometheusConfig := receivers["prometheus"].(map[string]interface{}) - config := prometheusConfig["config"].(map[string]interface{}) - scrapeConfigs := config["scrape_configs"].([]interface{}) + prometheusConfig := receivers[PrometheusReceiver].(map[string]interface{}) + config := prometheusConfig[KeyConfig].(map[string]interface{}) + scrapeConfigs := config[KeyScrapeConfigs].([]interface{}) scrapeConfig := scrapeConfigs[0].(map[string]interface{}) - assert.Equal(t, tc.expectedConfig["job_name"], scrapeConfig["job_name"]) - assert.Equal(t, tc.expectedConfig["scrape_interval"], scrapeConfig["scrape_interval"]) + assert.Equal(t, tc.expectedConfig[KeyJobName], scrapeConfig[KeyJobName]) + assert.Equal(t, tc.expectedConfig[KeyScrapeInterval], scrapeConfig[KeyScrapeInterval]) - staticConfigs := scrapeConfig["static_configs"].([]interface{}) + staticConfigs := scrapeConfig[KeyStaticConfigs].([]interface{}) staticConfig := staticConfigs[0].(map[string]interface{}) - targets := staticConfig["targets"].([]interface{}) + targets := staticConfig[KeyTargets].([]interface{}) - assert.Equal(t, tc.expectedConfig["static_configs"].([]interface{})[0].(map[string]interface{})["targets"], targets) + assert.Equal(t, tc.expectedConfig[KeyStaticConfigs].([]interface{})[0].(map[string]interface{})[KeyTargets], targets) // Verify filter processor if metric names exist if len(tc.metricNames) > 0 { processors := collector.Spec.Config.Processors.Object - filterMetrics := processors["filter/metrics"].(map[string]interface{}) - metrics := filterMetrics["metrics"].(map[string]interface{}) - include := metrics["include"].(map[string]interface{}) - metricNames := include["metric_names"].([]string) + filterMetrics := 
processors[ProcessorFilterMetrics].(map[string]interface{}) + metrics := filterMetrics[KeyMetrics].(map[string]interface{}) + include := metrics[KeyInclude].(map[string]interface{}) + metricNames := include[KeyMetricNames].([]string) + assert.ElementsMatch(t, tc.metricNames, metricNames) + } + + // Verify processors always include resourcedetection/env and transform + processors := collector.Spec.Config.Processors.Object + assert.Contains(t, processors, ProcessorResourcedetectionEnv) + assert.Contains(t, processors, ProcessorTransform) + + // Verify pipeline processors + pipeline := collector.Spec.Config.Service.Pipelines[PipelineMetrics].Processors + if len(tc.metricNames) > 0 { + assert.Equal(t, []string{ProcessorResourcedetectionEnv, ProcessorTransform, ProcessorFilterMetrics}, pipeline) + // Verify filter processor config + filterMetrics := processors[ProcessorFilterMetrics].(map[string]interface{}) + metrics := filterMetrics[KeyMetrics].(map[string]interface{}) + include := metrics[KeyInclude].(map[string]interface{}) + metricNames := include[KeyMetricNames].([]string) assert.ElementsMatch(t, tc.metricNames, metricNames) - // Verify processors in pipeline - assert.Equal(t, []string{"filter/metrics"}, collector.Spec.Config.Service.Pipelines["metrics"].Processors) } else { - assert.Empty(t, collector.Spec.Config.Service.Pipelines["metrics"].Processors) + assert.Equal(t, []string{ProcessorResourcedetectionEnv, ProcessorTransform}, pipeline) } }) } @@ -193,15 +208,15 @@ func TestReconcileUpdate(t *testing.T) { Mode: otelv1beta1.ModeSidecar, Config: otelv1beta1.Config{ Receivers: otelv1beta1.AnyConfig{Object: map[string]interface{}{ - "prometheus": map[string]interface{}{ - "config": map[string]interface{}{ - "scrape_configs": []interface{}{ + PrometheusReceiver: map[string]interface{}{ + KeyConfig: map[string]interface{}{ + KeyScrapeConfigs: []interface{}{ map[string]interface{}{ - "job_name": "old-collector", - "scrape_interval": "30s", - "static_configs": 
[]interface{}{ + KeyJobName: "old-collector", + KeyScrapeInterval: "30s", + KeyStaticConfigs: []interface{}{ map[string]interface{}{ - "targets": []interface{}{"localhost:8080"}, + KeyTargets: []interface{}{"localhost:8080"}, }, }, }, @@ -211,8 +226,8 @@ func TestReconcileUpdate(t *testing.T) { }}, Service: otelv1beta1.Service{ Pipelines: map[string]*otelv1beta1.Pipeline{ - "metrics": { - Receivers: []string{"prometheus"}, + PipelineMetrics: { + Receivers: []string{PrometheusReceiver}, Processors: []string{}, Exporters: []string{"otlp"}, }, @@ -240,13 +255,13 @@ func TestReconcileUpdate(t *testing.T) { // Verify updated config receivers := updatedCollector.Spec.Config.Receivers.Object - prometheusConfig := receivers["prometheus"].(map[string]interface{}) - config := prometheusConfig["config"].(map[string]interface{}) - scrapeConfigs := config["scrape_configs"].([]interface{}) + prometheusConfig := receivers[PrometheusReceiver].(map[string]interface{}) + config := prometheusConfig[KeyConfig].(map[string]interface{}) + scrapeConfigs := config[KeyScrapeConfigs].([]interface{}) scrapeConfig := scrapeConfigs[0].(map[string]interface{}) - assert.Equal(t, "otel-collector", scrapeConfig["job_name"]) - assert.Equal(t, "15s", scrapeConfig["scrape_interval"]) + assert.Equal(t, JobNameOtelCollector, scrapeConfig[KeyJobName]) + assert.Equal(t, "15s", scrapeConfig[KeyScrapeInterval]) } func TestSetControllerReferences(t *testing.T) { diff --git a/test/e2e/predictor/test_autoscaling.py b/test/e2e/predictor/test_autoscaling.py index c7444575846..37e90d3e964 100644 --- a/test/e2e/predictor/test_autoscaling.py +++ b/test/e2e/predictor/test_autoscaling.py @@ -649,7 +649,7 @@ async def test_scaling_sklearn_with_keda_otel_add_on(rest_v1_client, network_lay trigger_metadata = scaledobject_resp["items"][0]["spec"]["triggers"][0]["metadata"] trigger_type = scaledobject_resp["items"][0]["spec"]["triggers"][0]["type"] assert trigger_type == "external" - assert 
trigger_metadata["metricQuery"] == "http_requests_per_second" + assert trigger_metadata["metricQuery"] == 'sum(http_requests_per_second{namespace="kserve-ci-e2e-test", deployment="isvc-sklearn-keda-otel-add-on-predictor"})' assert trigger_metadata["scalerAddress"] == "keda-otel-scaler.keda.svc:4318" assert trigger_metadata["targetValue"] == "50.000000" res = await predict_isvc( From 427dda575bba7bba1680a58eed3ad08ac31c6cf5 Mon Sep 17 00:00:00 2001 From: Sivanantham <90966311+sivanantha321@users.noreply.github.com> Date: Sun, 20 Jul 2025 14:16:38 +0530 Subject: [PATCH 3/9] feat: Add disable postprocessing option for raw logits (#4566) Signed-off-by: Sivanantham Chinnaiyan --- .../huggingfaceserver/__main__.py | 14 +- .../huggingfaceserver/encoder_model.py | 285 +++++++++++++--- python/huggingfaceserver/tests/test_model.py | 59 +++- python/huggingfaceserver/tests/test_output.py | 314 +++++++++--------- test/e2e/predictor/test_huggingface.py | 66 +++- test/e2e/predictor/test_output.py | 2 +- 6 files changed, 516 insertions(+), 224 deletions(-) diff --git a/python/huggingfaceserver/huggingfaceserver/__main__.py b/python/huggingfaceserver/huggingfaceserver/__main__.py index 7fb41f6a0af..7936418388b 100644 --- a/python/huggingfaceserver/huggingfaceserver/__main__.py +++ b/python/huggingfaceserver/huggingfaceserver/__main__.py @@ -138,10 +138,19 @@ def is_vllm_backend_enabled( parser.add_argument( "--return_token_type_ids", action="store_true", help="Return token type ids" ) -parser.add_argument( + +# Create a mutually exclusive group for output format options +# This group allows the user to choose between returning probabilities or disabling postprocessing. 
+output_format_group = parser.add_mutually_exclusive_group() +output_format_group.add_argument( "--return_probabilities", action="store_true", - help="Return all probabilities", + help="Return probabilities instead of logits for classification tasks such as token classification, text classification and fill-mask.", +) +output_format_group.add_argument( + "--return_raw_logits", + action="store_true", + help="Return raw logits without processing. Supported only classification tasks such as token classification, text classification and fill-mask.", ) parser.add_argument( "--disable_log_requests", action="store_true", help="Disable logging requests" @@ -291,6 +300,7 @@ def load_model(): return_token_type_ids=kwargs.get("return_token_type_ids", None), request_logger=request_logger, return_probabilities=kwargs.get("return_probabilities", False), + return_raw_logits=kwargs.get("return_raw_logits", False), ) model.load() return model diff --git a/python/huggingfaceserver/huggingfaceserver/encoder_model.py b/python/huggingfaceserver/huggingfaceserver/encoder_model.py index 267d40828bd..e352fc4f97b 100644 --- a/python/huggingfaceserver/huggingfaceserver/encoder_model.py +++ b/python/huggingfaceserver/huggingfaceserver/encoder_model.py @@ -13,7 +13,7 @@ # limitations under the License. 
import base64 import pathlib -from typing import Any, Dict, AsyncGenerator, Optional, Union +from typing import Any, Dict, AsyncGenerator, List, Optional, Tuple, Union from fastapi import Request import struct @@ -102,6 +102,7 @@ def __init__( trust_remote_code: bool = False, return_probabilities: bool = False, request_logger: Optional[RequestLogger] = None, + return_raw_logits: bool = False, ): super().__init__(model_name) self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -116,6 +117,7 @@ def __init__( self.tokenizer_revision = tokenizer_revision self.trust_remote_code = trust_remote_code self.return_probabilities = return_probabilities + self.return_raw_logits = return_raw_logits self.request_logger = request_logger if model_config: @@ -292,69 +294,250 @@ async def predict( def postprocess( self, outputs: Union[Tensor, InferResponse], context: Dict[str, Any] ) -> Union[Dict, InferResponse]: - input_ids = context["input_ids"] - request = context["payload"] - if isinstance(outputs, InferResponse): - shape = torch.Size(outputs.outputs[0].shape) - data = torch.Tensor(outputs.outputs[0].data) - outputs = data.view(shape) - input_ids = torch.Tensor(input_ids) - inferences = [] + """ + Process model outputs based on the ML task. 
+ + Args: + outputs: Model output tensor or inference response + context: Dictionary containing input_ids, attention_mask and payload + + Returns: + Processed inference results as Dict or InferResponse + """ + normalized_outputs, input_ids, request = self._normalize_inputs( + outputs, context + ) + if self.task == MLTask.sequence_classification: - num_rows, num_cols = outputs.shape - for i in range(num_rows): - out = outputs[i].unsqueeze(0) - if self.return_probabilities: - inferences.append(dict(enumerate(out.numpy().flatten()))) - else: - predicted_idx = out.argmax().item() - inferences.append(predicted_idx) + inferences = self._process_sequence_classification(normalized_outputs) return get_predict_response(request, inferences, self.name) elif self.task == MLTask.fill_mask: - num_rows = outputs.shape[0] - for i in range(num_rows): - mask_pos = (input_ids == self._tokenizer.mask_token_id)[i] - mask_token_index = mask_pos.nonzero(as_tuple=True)[0] - if self.return_probabilities: - probabilities = torch.softmax(outputs[i, mask_token_index], dim=-1) - decoded_probabilities = [] - for idx, probs in enumerate(probabilities): - token_probs = [] - for token_id, prob in enumerate(probs): - token = self._tokenizer.decode([token_id]) - token_probs.append({f"{token}": f"{prob.item():.4f}"}) - decoded_probabilities.append(token_probs) - inferences.append(decoded_probabilities) - else: - predicted_token_id = outputs[i, mask_token_index].argmax(axis=-1) - inferences.append(self._tokenizer.decode(predicted_token_id)) + inferences = self._process_fill_mask(normalized_outputs, input_ids) return get_predict_response(request, inferences, self.name) elif self.task == MLTask.token_classification: - num_rows = outputs.shape[0] - for i in range(num_rows): - output = outputs[i].unsqueeze(0) - if self.return_probabilities: - for values in output.tolist(): - res = [{k: v for k, v in enumerate(value)} for value in values] - inferences.append([res]) - else: - predictions = 
torch.argmax(output, dim=2) - inferences.append(predictions.tolist()) + inferences = self._process_token_classification(normalized_outputs) return get_predict_response(request, inferences, self.name) elif self.task == MLTask.text_embedding: - # Perform pooling - outputs = _mean_pooling(outputs, context["attention_mask"]) - # Normalize embeddings - outputs = F.normalize(outputs, p=2, dim=1) - num_rows, _ = outputs.shape - for i in range(num_rows): - inferences.append(outputs[i].tolist()) + inferences = self._process_text_embedding( + normalized_outputs, context["attention_mask"] + ) return get_predict_response(request, inferences, self.name) else: raise OpenAIError( f"Unsupported task {self.task}. Please check the supported `task` option." ) + def _normalize_inputs( + self, outputs: Union[Tensor, InferResponse], context: Dict[str, Any] + ) -> Tuple[Tensor, Union[Tensor, list], Dict]: + """ + Normalize model outputs and extract context items. + + Args: + outputs: Model output tensor or inference response + context: Dictionary containing input_ids and payload + + Returns: + Tuple of (normalized_outputs, input_ids, request) + """ + input_ids = context["input_ids"] + request = context["payload"] + + if isinstance(outputs, InferResponse): + shape = torch.Size(outputs.outputs[0].shape) + data = torch.Tensor(outputs.outputs[0].data) + outputs = data.view(shape) + input_ids = torch.Tensor(input_ids) + + return outputs, input_ids, request + + def _process_sequence_classification( + self, outputs: Tensor + ) -> List[Union[int, Dict]]: + """ + Process outputs for sequence classification task. 
+ + Args: + outputs: Model output tensor + + Returns: + List of processed inferences (indices or probability dictionaries) + """ + inferences = [] + num_rows, _ = outputs.shape + + for i in range(num_rows): + out = outputs[i].unsqueeze(0) + if self.return_raw_logits: + logits = out.squeeze() + logits = logits.cpu() if logits.is_cuda else logits + inferences.append({j: logits[j].item() for j in range(logits.size(0))}) + elif self.return_probabilities: + probs = torch.softmax(out, dim=-1).squeeze() + probs = probs.cpu() if probs.is_cuda else probs + inferences.append( + {j: float(f"{probs[j]:.4f}") for j in range(probs.size(0))} + ) + else: + predicted_idx = out.argmax().item() + inferences.append(predicted_idx) + + return inferences + + def _process_fill_mask( + self, outputs: Tensor, input_ids: Union[Tensor, list] + ) -> List[Union[str, List]]: + """ + Process outputs for fill mask task. + + Args: + outputs: Model output tensor + input_ids: Input token IDs + + Returns: + List of processed inferences (token strings or token probability lists) + """ + inferences = [] + num_rows = outputs.shape[0] + + for i in range(num_rows): + if isinstance(input_ids, torch.Tensor): + # Find where the mask token is in this sequence + ids_row = input_ids[i] + mask_token_indices = [] + for j in range(ids_row.size(0)): + if ids_row[j].item() == self._tokenizer.mask_token_id: + mask_token_indices.append(j) + mask_token_index = torch.tensor(mask_token_indices) + else: + # Handle list inputs + try: + mask_token_index = input_ids[i].index(self._tokenizer.mask_token_id) + mask_token_index = torch.tensor([mask_token_index]) + except (ValueError, AttributeError): + # Fallback if mask token not found or input_ids is not a list + mask_token_index = torch.tensor([0]) # Use first token as fallback + masked_output = outputs[i, mask_token_index] + + if self.return_raw_logits: + inferences.append(self._process_mask_logits(masked_output)) + elif self.return_probabilities: + 
inferences.append(self._process_mask_probabilities(masked_output)) + else: + predicted_token_id = masked_output.argmax(dim=-1) + predicted_token = self._tokenizer.decode(predicted_token_id) + inferences.append(predicted_token) + + return inferences + + def _process_mask_logits( + self, masked_output: Tensor + ) -> List[List[Dict[str, float]]]: + """ + Process mask logits into token-logit dictionaries. + + Args: + masked_output: Tensor of logits for masked tokens + + Returns: + Nested lists of token-logit dictionaries + """ + decoded_logits = [] + masked_output = masked_output.cpu() if masked_output.is_cuda else masked_output + for logits in masked_output: + token_logits = [] + for token_id, logit in enumerate(logits): + token_logits.append({token_id: logit.item()}) + decoded_logits.append(token_logits) + return decoded_logits + + def _process_mask_probabilities( + self, masked_output: Tensor + ) -> List[List[Dict[str, str]]]: + """ + Process mask probabilities into token-probability dictionaries. + + Args: + masked_output: Tensor of logits for masked tokens + + Returns: + Nested lists of token-probability dictionaries + """ + probabilities = torch.softmax(masked_output, dim=-1) + decoded_probabilities = [] + probabilities = probabilities.cpu() if probabilities.is_cuda else probabilities + for probs in probabilities: + token_probs = [] + for token_id, prob in enumerate(probs): + token_probs.append({token_id: f"{prob.item():.4f}"}) + decoded_probabilities.append(token_probs) + return decoded_probabilities + + def _process_token_classification( + self, outputs: Tensor + ) -> List[Union[List, List[Dict]]]: + """ + Process outputs for token classification task. 
+ + Args: + outputs: Model output tensor + + Returns: + List of processed token classifications (indices or probability dictionaries) + """ + inferences = [] + num_rows = outputs.shape[0] + for i in range(num_rows): + output = outputs[i].unsqueeze(0) + + if self.return_raw_logits: + token_logits = [] + output = output.cpu() if output.is_cuda else output + for values in output.squeeze(0): + token_logits.append( + {j: values[j].item() for j in range(values.size(0))} + ) + inferences.append(token_logits) + elif self.return_probabilities: + probs = torch.softmax(output, dim=-1) + probs = probs.cpu() if probs.is_cuda else probs + token_probs = [] + for values in probs.squeeze(0): + token_probs.append( + {j: float(f"{values[j]:.4f}") for j in range(values.size(0))} + ) + inferences.append(token_probs) + else: + predictions = torch.argmax(output, dim=2) + inferences.append(predictions.tolist()) + + return inferences + + def _process_text_embedding( + self, outputs: Tensor, attention_mask: Tensor + ) -> List[List[float]]: + """ + Process outputs for text embedding task. 
+ + Args: + outputs: Model output tensor + attention_mask: Attention mask tensor + + Returns: + List of normalized embeddings + """ + # Perform pooling + outputs = _mean_pooling(outputs, attention_mask) + # Normalize embeddings + outputs = F.normalize(outputs, p=2, dim=1) + outputs = outputs.cpu() if outputs.is_cuda else outputs + inferences = [] + num_rows, _ = outputs.shape + for i in range(num_rows): + inferences.append(outputs[i].tolist()) + + return inferences + def _log_request(self, request_id: str, prompt: list[str]) -> None: if self.request_logger: self.request_logger.log_inputs( diff --git a/python/huggingfaceserver/tests/test_model.py b/python/huggingfaceserver/tests/test_model.py index 692ef942c69..f8b7477b39d 100644 --- a/python/huggingfaceserver/tests/test_model.py +++ b/python/huggingfaceserver/tests/test_model.py @@ -28,7 +28,7 @@ from huggingfaceserver.encoder_model import HuggingfaceEncoderModel from huggingfaceserver.generative_model import HuggingfaceGenerativeModel from huggingfaceserver.task import MLTask -from test_output import bert_token_classification_return_prob_expected_output +from test_output import bert_token_classification_return_raw_logits_expected_output import torch.nn.functional as F @@ -97,6 +97,19 @@ def bert_base_return_prob(): model.stop() +@pytest.fixture(scope="module") +def bert_base_return_raw_logits(): + model = HuggingfaceEncoderModel( + "bert-base-uncased-yelp-polarity", + model_id_or_path="textattack/bert-base-uncased-yelp-polarity", + task=MLTask.sequence_classification, + return_raw_logits=True, + ) + model.load() + yield model + model.stop() + + @pytest.fixture(scope="module") def bert_token_classification_return_prob(): model = HuggingfaceEncoderModel( @@ -111,6 +124,20 @@ def bert_token_classification_return_prob(): model.stop() +@pytest.fixture(scope="module") +def bert_token_classification_return_raw_logits(): + model = HuggingfaceEncoderModel( + "bert-large-cased-finetuned-conll03-english", + 
model_id_or_path="dbmdz/bert-large-cased-finetuned-conll03-english", + do_lower_case=True, + add_special_tokens=False, + return_raw_logits=True, + ) + model.load() + yield model + model.stop() + + @pytest.fixture(scope="module") def bert_token_classification(): model = HuggingfaceEncoderModel( @@ -251,24 +278,42 @@ async def test_bert_sequence_classification_return_probabilities(bert_base_retur {"instances": [request, request]}, headers={} ) + assert response == {"predictions": [{0: 0.0012, 1: 0.9988}, {0: 0.0012, 1: 0.9988}]} + + +@pytest.mark.asyncio +async def test_bert_sequence_classification_return_raw_logits( + bert_base_return_raw_logits, +): + request = "Hello, my dog is cute." + response, _ = await bert_base_return_raw_logits( + {"instances": [request, request]}, headers={} + ) + assert response == { "predictions": [ - {0: approx(-3.1508713), 1: approx(3.5892851)}, - {0: approx(-3.1508713), 1: approx(3.589285)}, + { + 0: approx(-3.1508712768554688, abs=0.000009), + 1: approx(3.589285135269165, abs=0.000009), + }, + { + 0: approx(-3.1508712768554688, abs=0.000009), + 1: approx(3.589284896850586, abs=0.000009), + }, ] } @pytest.mark.asyncio -async def test_bert_token_classification_return_prob( - bert_token_classification_return_prob, +async def test_bert_token_classification_return_raw_logits( + bert_token_classification_return_raw_logits, ): request = "Hello, my dog is cute." 
- response, _ = await bert_token_classification_return_prob( + response, _ = await bert_token_classification_return_raw_logits( {"instances": [request, request]}, headers={} ) - assert response == bert_token_classification_return_prob_expected_output + assert response == bert_token_classification_return_raw_logits_expected_output @pytest.mark.asyncio diff --git a/python/huggingfaceserver/tests/test_output.py b/python/huggingfaceserver/tests/test_output.py index 81b192640e1..7d6767248f8 100644 --- a/python/huggingfaceserver/tests/test_output.py +++ b/python/huggingfaceserver/tests/test_output.py @@ -15,169 +15,165 @@ from pytest import approx -bert_token_classification_return_prob_expected_output = { +bert_token_classification_return_raw_logits_expected_output = { "predictions": [ [ - [ - { - 0: approx(10.125612258911133, abs=0.000009), - 1: approx(-1.818464756011963, abs=0.000009), - 2: approx(-1.3191171884536743, abs=0.000009), - 3: approx(-1.9324339628219604, abs=0.000009), - 4: approx(-1.4850239753723145, abs=0.000009), - 5: approx(-1.69266676902771, abs=0.000009), - 6: approx(-0.898107647895813, abs=0.000009), - 7: approx(-1.725127935409546, abs=0.000009), - 8: approx(0.32057392597198486, abs=0.000009), - }, - { - 0: approx(9.647306442260742, abs=0.000009), - 1: approx(-2.112745523452759, abs=0.000009), - 2: approx(-0.8831668496131897, abs=0.000009), - 3: approx(-2.719135046005249, abs=0.000009), - 4: approx(-0.47473397850990295, abs=0.000009), - 5: approx(-2.2424018383026123, abs=0.000009), - 6: approx(0.6101505160331726, abs=0.000009), - 7: approx(-2.2192084789276123, abs=0.000009), - 8: approx(-0.654518187046051, abs=0.000009), - }, - { - 0: approx(10.364561080932617, abs=0.000009), - 1: approx(-2.240158796310425, abs=0.000009), - 2: approx(-0.9236820340156555, abs=0.000009), - 3: approx(-2.623302936553955, abs=0.000009), - 4: approx(-0.501063346862793, abs=0.000009), - 5: approx(-1.9418426752090454, abs=0.000009), - 6: approx(-0.041013482958078384, 
abs=0.000009), - 7: approx(-2.12089204788208, abs=0.000009), - 8: approx(-1.2565152645111084, abs=0.000009), - }, - { - 0: approx(9.21961784362793, abs=0.000009), - 1: approx(-2.1359012126922607, abs=0.000009), - 2: approx(-0.1689995676279068, abs=0.000009), - 3: approx(-3.0277929306030273, abs=0.000009), - 4: approx(0.2589835822582245, abs=0.000009), - 5: approx(-2.4426753520965576, abs=0.000009), - 6: approx(0.4815778136253357, abs=0.000009), - 7: approx(-2.3223936557769775, abs=0.000009), - 8: approx(-0.23837946355342865, abs=0.000009), - }, - { - 0: approx(11.023712158203125, abs=0.000009), - 1: approx(-2.0757784843444824, abs=0.000009), - 2: approx(-0.8648976683616638, abs=0.000009), - 3: approx(-2.6088907718658447, abs=0.000009), - 4: approx(-0.781069278717041, abs=0.000009), - 5: approx(-1.7040152549743652, abs=0.000009), - 6: approx(-0.5265535712242126, abs=0.000009), - 7: approx(-1.7141871452331543, abs=0.000009), - 8: approx(-1.0127886533737183, abs=0.000009), - }, - { - 0: approx(9.288057327270508, abs=0.000009), - 1: approx(-1.9795234203338623, abs=0.000009), - 2: approx(0.3095782697200775, abs=0.000009), - 3: approx(-2.781409740447998, abs=0.000009), - 4: approx(-0.4492073655128479, abs=0.000009), - 5: approx(-1.7794851064682007, abs=0.000009), - 6: approx(-0.3115033507347107, abs=0.000009), - 7: approx(-2.0755014419555664, abs=0.000009), - 8: approx(-0.8829556107521057, abs=0.000009), - }, - { - 0: approx(7.968824863433838, abs=0.000009), - 1: approx(-2.0212578773498535, abs=0.000009), - 2: approx(-0.43740716576576233, abs=0.000009), - 3: approx(-3.6762948036193848, abs=0.000009), - 4: approx(-0.012759938836097717, abs=0.000009), - 5: approx(-2.320978879928589, abs=0.000009), - 6: approx(0.6871926784515381, abs=0.000009), - 7: approx(-2.541991949081421, abs=0.000009), - 8: approx(0.32606029510498047, abs=0.000009), - }, - ] + { + 0: approx(10.125612258911133, abs=0.000009), + 1: approx(-1.818464756011963, abs=0.000009), + 2: 
approx(-1.3191171884536743, abs=0.000009), + 3: approx(-1.9324339628219604, abs=0.000009), + 4: approx(-1.4850239753723145, abs=0.000009), + 5: approx(-1.69266676902771, abs=0.000009), + 6: approx(-0.898107647895813, abs=0.000009), + 7: approx(-1.725127935409546, abs=0.000009), + 8: approx(0.32057392597198486, abs=0.000009), + }, + { + 0: approx(9.647306442260742, abs=0.000009), + 1: approx(-2.112745523452759, abs=0.000009), + 2: approx(-0.8831668496131897, abs=0.000009), + 3: approx(-2.719135046005249, abs=0.000009), + 4: approx(-0.47473397850990295, abs=0.000009), + 5: approx(-2.2424018383026123, abs=0.000009), + 6: approx(0.6101505160331726, abs=0.000009), + 7: approx(-2.2192084789276123, abs=0.000009), + 8: approx(-0.654518187046051, abs=0.000009), + }, + { + 0: approx(10.364561080932617, abs=0.000009), + 1: approx(-2.240158796310425, abs=0.000009), + 2: approx(-0.9236820340156555, abs=0.000009), + 3: approx(-2.623302936553955, abs=0.000009), + 4: approx(-0.501063346862793, abs=0.000009), + 5: approx(-1.9418426752090454, abs=0.000009), + 6: approx(-0.041013482958078384, abs=0.000009), + 7: approx(-2.12089204788208, abs=0.000009), + 8: approx(-1.2565152645111084, abs=0.000009), + }, + { + 0: approx(9.21961784362793, abs=0.000009), + 1: approx(-2.1359012126922607, abs=0.000009), + 2: approx(-0.1689995676279068, abs=0.000009), + 3: approx(-3.0277929306030273, abs=0.000009), + 4: approx(0.2589835822582245, abs=0.000009), + 5: approx(-2.4426753520965576, abs=0.000009), + 6: approx(0.4815778136253357, abs=0.000009), + 7: approx(-2.3223936557769775, abs=0.000009), + 8: approx(-0.23837946355342865, abs=0.000009), + }, + { + 0: approx(11.023712158203125, abs=0.000009), + 1: approx(-2.0757784843444824, abs=0.000009), + 2: approx(-0.8648976683616638, abs=0.000009), + 3: approx(-2.6088907718658447, abs=0.000009), + 4: approx(-0.781069278717041, abs=0.000009), + 5: approx(-1.7040152549743652, abs=0.000009), + 6: approx(-0.5265535712242126, abs=0.000009), + 7: 
approx(-1.7141871452331543, abs=0.000009), + 8: approx(-1.0127886533737183, abs=0.000009), + }, + { + 0: approx(9.288057327270508, abs=0.000009), + 1: approx(-1.9795234203338623, abs=0.000009), + 2: approx(0.3095782697200775, abs=0.000009), + 3: approx(-2.781409740447998, abs=0.000009), + 4: approx(-0.4492073655128479, abs=0.000009), + 5: approx(-1.7794851064682007, abs=0.000009), + 6: approx(-0.3115033507347107, abs=0.000009), + 7: approx(-2.0755014419555664, abs=0.000009), + 8: approx(-0.8829556107521057, abs=0.000009), + }, + { + 0: approx(7.968824863433838, abs=0.000009), + 1: approx(-2.0212578773498535, abs=0.000009), + 2: approx(-0.43740716576576233, abs=0.000009), + 3: approx(-3.6762948036193848, abs=0.000009), + 4: approx(-0.012759938836097717, abs=0.000009), + 5: approx(-2.320978879928589, abs=0.000009), + 6: approx(0.6871926784515381, abs=0.000009), + 7: approx(-2.541991949081421, abs=0.000009), + 8: approx(0.32606029510498047, abs=0.000009), + }, ], [ - [ - { - 0: approx(10.125612258911133, abs=0.000009), - 1: approx(-1.818464756011963, abs=0.000009), - 2: approx(-1.3191171884536743, abs=0.000009), - 3: approx(-1.9324339628219604, abs=0.000009), - 4: approx(-1.4850239753723145, abs=0.000009), - 5: approx(-1.69266676902771, abs=0.000009), - 6: approx(-0.898107647895813, abs=0.000009), - 7: approx(-1.725127935409546, abs=0.000009), - 8: approx(0.32057392597198486, abs=0.000009), - }, - { - 0: approx(9.647306442260742, abs=0.000009), - 1: approx(-2.112745523452759, abs=0.000009), - 2: approx(-0.8831668496131897, abs=0.000009), - 3: approx(-2.719135046005249, abs=0.000009), - 4: approx(-0.47473397850990295, abs=0.000009), - 5: approx(-2.2424018383026123, abs=0.000009), - 6: approx(0.6101505160331726, abs=0.000009), - 7: approx(-2.2192084789276123, abs=0.000009), - 8: approx(-0.654518187046051, abs=0.000009), - }, - { - 0: approx(10.364561080932617, abs=0.000009), - 1: approx(-2.240158796310425, abs=0.000009), - 2: approx(-0.9236820340156555, abs=0.000009), - 
3: approx(-2.623302936553955, abs=0.000009), - 4: approx(-0.501063346862793, abs=0.000009), - 5: approx(-1.9418426752090454, abs=0.000009), - 6: approx(-0.041013482958078384, abs=0.000009), - 7: approx(-2.12089204788208, abs=0.000009), - 8: approx(-1.2565152645111084, abs=0.000009), - }, - { - 0: approx(9.21961784362793, abs=0.000009), - 1: approx(-2.1359012126922607, abs=0.000009), - 2: approx(-0.1689995676279068, abs=0.000009), - 3: approx(-3.0277929306030273, abs=0.000009), - 4: approx(0.2589835822582245, abs=0.000009), - 5: approx(-2.4426753520965576, abs=0.000009), - 6: approx(0.4815778136253357, abs=0.000009), - 7: approx(-2.3223936557769775, abs=0.000009), - 8: approx(-0.23837946355342865, abs=0.000009), - }, - { - 0: approx(11.023712158203125, abs=0.000009), - 1: approx(-2.0757784843444824, abs=0.000009), - 2: approx(-0.8648976683616638, abs=0.000009), - 3: approx(-2.6088907718658447, abs=0.000009), - 4: approx(-0.781069278717041, abs=0.000009), - 5: approx(-1.7040152549743652, abs=0.000009), - 6: approx(-0.5265535712242126, abs=0.000009), - 7: approx(-1.7141871452331543, abs=0.000009), - 8: approx(-1.0127886533737183, abs=0.000009), - }, - { - 0: approx(9.288057327270508, abs=0.000009), - 1: approx(-1.9795234203338623, abs=0.000009), - 2: approx(0.3095782697200775, abs=0.000009), - 3: approx(-2.781409740447998, abs=0.000009), - 4: approx(-0.4492073655128479, abs=0.000009), - 5: approx(-1.7794851064682007, abs=0.000009), - 6: approx(-0.3115033507347107, abs=0.000009), - 7: approx(-2.0755014419555664, abs=0.000009), - 8: approx(-0.8829556107521057, abs=0.000009), - }, - { - 0: approx(7.968824863433838, abs=0.000009), - 1: approx(-2.0212578773498535, abs=0.000009), - 2: approx(-0.43740716576576233, abs=0.000009), - 3: approx(-3.6762948036193848, abs=0.000009), - 4: approx(-0.012759938836097717, abs=0.000009), - 5: approx(-2.320978879928589, abs=0.000009), - 6: approx(0.6871926784515381, abs=0.000009), - 7: approx(-2.541991949081421, abs=0.000009), - 8: 
approx(0.32606029510498047, abs=0.000009), - }, - ] + { + 0: approx(10.125612258911133, abs=0.000009), + 1: approx(-1.818464756011963, abs=0.000009), + 2: approx(-1.3191171884536743, abs=0.000009), + 3: approx(-1.9324339628219604, abs=0.000009), + 4: approx(-1.4850239753723145, abs=0.000009), + 5: approx(-1.69266676902771, abs=0.000009), + 6: approx(-0.898107647895813, abs=0.000009), + 7: approx(-1.725127935409546, abs=0.000009), + 8: approx(0.32057392597198486, abs=0.000009), + }, + { + 0: approx(9.647306442260742, abs=0.000009), + 1: approx(-2.112745523452759, abs=0.000009), + 2: approx(-0.8831668496131897, abs=0.000009), + 3: approx(-2.719135046005249, abs=0.000009), + 4: approx(-0.47473397850990295, abs=0.000009), + 5: approx(-2.2424018383026123, abs=0.000009), + 6: approx(0.6101505160331726, abs=0.000009), + 7: approx(-2.2192084789276123, abs=0.000009), + 8: approx(-0.654518187046051, abs=0.000009), + }, + { + 0: approx(10.364561080932617, abs=0.000009), + 1: approx(-2.240158796310425, abs=0.000009), + 2: approx(-0.9236820340156555, abs=0.000009), + 3: approx(-2.623302936553955, abs=0.000009), + 4: approx(-0.501063346862793, abs=0.000009), + 5: approx(-1.9418426752090454, abs=0.000009), + 6: approx(-0.041013482958078384, abs=0.000009), + 7: approx(-2.12089204788208, abs=0.000009), + 8: approx(-1.2565152645111084, abs=0.000009), + }, + { + 0: approx(9.21961784362793, abs=0.000009), + 1: approx(-2.1359012126922607, abs=0.000009), + 2: approx(-0.1689995676279068, abs=0.000009), + 3: approx(-3.0277929306030273, abs=0.000009), + 4: approx(0.2589835822582245, abs=0.000009), + 5: approx(-2.4426753520965576, abs=0.000009), + 6: approx(0.4815778136253357, abs=0.000009), + 7: approx(-2.3223936557769775, abs=0.000009), + 8: approx(-0.23837946355342865, abs=0.000009), + }, + { + 0: approx(11.023712158203125, abs=0.000009), + 1: approx(-2.0757784843444824, abs=0.000009), + 2: approx(-0.8648976683616638, abs=0.000009), + 3: approx(-2.6088907718658447, abs=0.000009), + 4: 
approx(-0.781069278717041, abs=0.000009), + 5: approx(-1.7040152549743652, abs=0.000009), + 6: approx(-0.5265535712242126, abs=0.000009), + 7: approx(-1.7141871452331543, abs=0.000009), + 8: approx(-1.0127886533737183, abs=0.000009), + }, + { + 0: approx(9.288057327270508, abs=0.000009), + 1: approx(-1.9795234203338623, abs=0.000009), + 2: approx(0.3095782697200775, abs=0.000009), + 3: approx(-2.781409740447998, abs=0.000009), + 4: approx(-0.4492073655128479, abs=0.000009), + 5: approx(-1.7794851064682007, abs=0.000009), + 6: approx(-0.3115033507347107, abs=0.000009), + 7: approx(-2.0755014419555664, abs=0.000009), + 8: approx(-0.8829556107521057, abs=0.000009), + }, + { + 0: approx(7.968824863433838, abs=0.000009), + 1: approx(-2.0212578773498535, abs=0.000009), + 2: approx(-0.43740716576576233, abs=0.000009), + 3: approx(-3.6762948036193848, abs=0.000009), + 4: approx(-0.012759938836097717, abs=0.000009), + 5: approx(-2.320978879928589, abs=0.000009), + 6: approx(0.6871926784515381, abs=0.000009), + 7: approx(-2.541991949081421, abs=0.000009), + 8: approx(0.32606029510498047, abs=0.000009), + }, ], ] } diff --git a/test/e2e/predictor/test_huggingface.py b/test/e2e/predictor/test_huggingface.py index 4fe521821c5..3588a3cb012 100644 --- a/test/e2e/predictor/test_huggingface.py +++ b/test/e2e/predictor/test_huggingface.py @@ -40,7 +40,7 @@ ) from .test_output import ( huggingface_text_embedding_expected_output, - huggingface_sequence_classification_with_probabilities_expected_output, + huggingface_sequence_classification_with_raw_logits_expected_output, ) from kserve.logging import trace_logger @@ -574,7 +574,7 @@ async def test_huggingface_openai_text_embedding(): @pytest.mark.llm @pytest.mark.asyncio(scope="session") -async def test_huggingface_v2_sequence_classification_with_probabilities( +async def test_huggingface_v2_sequence_classification_with_raw_logits( rest_v2_client, ): service_name = "hf-bert-sequence-v2-prob" @@ -595,7 +595,7 @@ async def 
test_huggingface_v2_sequence_classification_with_probabilities( "a4d0a85ea6c1d5bb944dcc12ea5c918863e469a4", "--backend", "huggingface", - "--return_probabilities", + "--return_raw_logits", ], resources=V1ResourceRequirements( requests={"cpu": "1", "memory": "2Gi"}, @@ -633,7 +633,65 @@ async def test_huggingface_v2_sequence_classification_with_probabilities( parsed_output = [ast.literal_eval(res.outputs[0].data[0])] assert ( parsed_output - == huggingface_sequence_classification_with_probabilities_expected_output + == huggingface_sequence_classification_with_raw_logits_expected_output + ) + + kserve_client.delete(service_name, KSERVE_TEST_NAMESPACE) + + +@pytest.mark.llm +@pytest.mark.asyncio(scope="session") +async def test_huggingface_v2_sequence_classification_with_probabilities( + rest_v2_client, +): + service_name = "hf-bert-sequence-v2-logits" + protocol_version = "v2" + predictor = V1beta1PredictorSpec( + min_replicas=1, + model=V1beta1ModelSpec( + model_format=V1beta1ModelFormat( + name="huggingface", + ), + protocol_version=protocol_version, + args=[ + "--model_id", + "textattack/bert-base-uncased-yelp-polarity", + "--model_revision", + "a4d0a85ea6c1d5bb944dcc12ea5c918863e469a4", + "--tokenizer_revision", + "a4d0a85ea6c1d5bb944dcc12ea5c918863e469a4", + "--backend", + "huggingface", + "--return_probabilities", + ], + resources=V1ResourceRequirements( + requests={"cpu": "1", "memory": "2Gi"}, + limits={"cpu": "1", "memory": "4Gi"}, + ), + ), + ) + + isvc = V1beta1InferenceService( + api_version=constants.KSERVE_V1BETA1, + kind=constants.KSERVE_KIND_INFERENCESERVICE, + metadata=client.V1ObjectMeta( + name=service_name, namespace=KSERVE_TEST_NAMESPACE + ), + spec=V1beta1InferenceServiceSpec(predictor=predictor), + ) + + kserve_client = KServeClient( + config_file=os.environ.get("KUBECONFIG", "~/.kube/config") + ) + kserve_client.create(isvc) + kserve_client.wait_isvc_ready(service_name, namespace=KSERVE_TEST_NAMESPACE) + + res = await predict_isvc( + 
rest_v2_client, + service_name, + "./data/bert_sequence_classification_v2.json", ) + output = ast.literal_eval(res.outputs[0].data[0]) + assert output == {0: 0.0094, 1: 0.9906} kserve_client.delete(service_name, KSERVE_TEST_NAMESPACE) diff --git a/test/e2e/predictor/test_output.py b/test/e2e/predictor/test_output.py index 21b719ee57c..7b97ef82170 100644 --- a/test/e2e/predictor/test_output.py +++ b/test/e2e/predictor/test_output.py @@ -401,7 +401,7 @@ approx(-0.029213037341833115, abs=1e-6), ] -huggingface_sequence_classification_with_probabilities_expected_output = [ +huggingface_sequence_classification_with_raw_logits_expected_output = [ { 0: approx(-2.152204, abs=0.000009), 1: approx(2.5094059, abs=0.000009), From 4edbb36c520c2e880842229bfc56b7f11d766822 Mon Sep 17 00:00:00 2001 From: Hannah DeFazio Date: Sun, 20 Jul 2025 21:57:35 -0400 Subject: [PATCH 4/9] Stop and resume an inference graph (#4588) Signed-off-by: Hannah DeFazio --- .../v1alpha1/inferencegraph/controller.go | 49 ++++ .../inferencegraph/controller_test.go | 269 +++++++++++++++++- .../inferencegraph/knative_reconciler.go | 24 +- 3 files changed, 328 insertions(+), 14 deletions(-) diff --git a/pkg/controller/v1alpha1/inferencegraph/controller.go b/pkg/controller/v1alpha1/inferencegraph/controller.go index eec10409ece..1073c53568f 100644 --- a/pkg/controller/v1alpha1/inferencegraph/controller.go +++ b/pkg/controller/v1alpha1/inferencegraph/controller.go @@ -143,6 +143,9 @@ func (r *InferenceGraphReconciler) Reconcile(ctx context.Context, req ctrl.Reque } r.Log.Info("Reconciling inference graph", "apiVersion", graph.APIVersion, "graph", graph.Name) + + forceStopRuntime := utils.GetForceStopRuntime(graph) + configMap, err := r.Clientset.CoreV1().ConfigMaps(constants.KServeNamespace).Get(ctx, constants.InferenceServiceConfigMapName, metav1.GetOptions{}) if err != nil { r.Log.Error(err, "Failed to find config map", "name", constants.InferenceServiceConfigMapName) @@ -257,6 +260,52 @@ func (r 
*InferenceGraphReconciler) Reconcile(ctx context.Context, req ctrl.Reque } } + // Handle InferenceGraph status updates based on the force stop annotation. + // If true, transition the service to a stopped and unready state; otherwise, ensure it's not marked as stopped. + transition_time := apis.VolatileTime{Inner: metav1.Now()} + existingStoppedCondition := graph.Status.GetCondition(v1beta1.Stopped) + if existingStoppedCondition == nil { + defaultStoppedCondition := apis.Condition{ + LastTransitionTime: transition_time, + Type: v1beta1.Stopped, + Status: corev1.ConditionFalse, + } + graph.Status.Conditions = append(graph.Status.Conditions, defaultStoppedCondition) + } + existingStoppedCondition = graph.Status.GetCondition(v1beta1.Stopped) + if forceStopRuntime { + // If the graph's stopped condition is not set or + // If the graph is currently running, update its status to signal that it should be stopped + if existingStoppedCondition.Status == corev1.ConditionFalse { + // Add the stopped condition + stoppedCondition := apis.Condition{ + LastTransitionTime: transition_time, + Type: v1beta1.Stopped, + Status: corev1.ConditionTrue, + } + readyCondition := apis.Condition{ + LastTransitionTime: transition_time, + Type: apis.ConditionReady, + Status: corev1.ConditionFalse, + Reason: v1beta1.StoppedISVCReason, + } + graph.Status.Conditions = []apis.Condition{stoppedCondition, readyCondition} + + graph.Status.URL = nil + } + } else { + // If the graph's stopped condition is not set or + // If the graph is currently stopped, update its status to signal that it should resume + if existingStoppedCondition.Status == corev1.ConditionTrue { + resumeCondition := apis.Condition{ + LastTransitionTime: transition_time, + Type: v1beta1.Stopped, + Status: corev1.ConditionFalse, + } + graph.Status.Conditions = append(graph.Status.Conditions, resumeCondition) + } + } + if err := r.updateStatus(ctx, graph); err != nil { r.Recorder.Eventf(graph, corev1.EventTypeWarning, "InternalError", 
err.Error()) return reconcile.Result{}, err diff --git a/pkg/controller/v1alpha1/inferencegraph/controller_test.go b/pkg/controller/v1alpha1/inferencegraph/controller_test.go index 3bed025d584..fd59d4def52 100644 --- a/pkg/controller/v1alpha1/inferencegraph/controller_test.go +++ b/pkg/controller/v1alpha1/inferencegraph/controller_test.go @@ -18,6 +18,7 @@ package inferencegraph import ( "context" + "fmt" "time" . "github.com/onsi/ginkgo/v2" @@ -25,9 +26,11 @@ import ( "google.golang.org/protobuf/proto" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + apierr "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "knative.dev/pkg/apis" "knative.dev/pkg/kmp" "knative.dev/serving/pkg/apis/autoscaling" knservingv1 "knative.dev/serving/pkg/apis/serving/v1" @@ -35,6 +38,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/reconcile" "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" + "github.com/kserve/kserve/pkg/apis/serving/v1beta1" "github.com/kserve/kserve/pkg/constants" "github.com/kserve/kserve/pkg/utils" ) @@ -49,18 +53,34 @@ var _ = Describe("Inference Graph controller test", func() { configs := map[string]string{ "router": `{ - "image": "kserve/router:v0.10.0", - "memoryRequest": "100Mi", - "memoryLimit": "500Mi", - "cpuRequest": "100m", - "cpuLimit": "100m", - "headers": { - "propagate": [ - "Authorization", - "Intuit_tid" - ] - } - }`, + "image": "kserve/router:v0.10.0", + "memoryRequest": "100Mi", + "memoryLimit": "500Mi", + "cpuRequest": "100m", + "cpuLimit": "100m", + "headers": { + "propagate": [ + "Authorization", + "Intuit_tid" + ] + } + }`, + "ingress": `{ + "kserveIngressGateway": "kserve/kserve-ingress-gateway", + "ingressGateway": "knative-serving/knative-ingress-gateway", + "localGateway": "knative-serving/knative-local-gateway", + "localGatewayService": "knative-local-gateway.istio-system.svc.cluster.local" + }`, + "storageInitializer": `{ 
+ "image" : "kserve/storage-initializer:latest", + "memoryRequest": "100Mi", + "memoryLimit": "1Gi", + "cpuRequest": "100m", + "cpuLimit": "1", + "CaBundleConfigMapName": "", + "caBundleVolumeMountPath": "/etc/ssl/custom-certs", + "enableDirectPvcVolumeMount": false + }`, } expectedReadinessProbe := constants.GetRouterReadinessProbe() @@ -1041,6 +1061,231 @@ var _ = Describe("Inference Graph controller test", func() { }) }) + Context("When creating an InferenceGraph with `serving.kserve.io/stop`", func() { + // --- Default values --- + createIGConfigMap := func() *corev1.ConfigMap { + configMap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: constants.InferenceServiceConfigMapName, + Namespace: constants.KServeNamespace, + }, + Data: configs, + } + return configMap + } + + // --- Reusable Check Functions --- + // Wait for the IG to exist. + expectIGToExist := func(ctx context.Context, serviceKey types.NamespacedName) v1alpha1.InferenceGraph { + actualIG := &v1alpha1.InferenceGraph{} + Eventually(func() bool { + err := k8sClient.Get(ctx, serviceKey, actualIG) + return err == nil + }, timeout, interval).Should(BeTrue()) + + return *actualIG + } + + // Waits for any Kubernestes object to be found + expectResourceToExist := func(ctx context.Context, obj client.Object, objKey types.NamespacedName) { + Eventually(func() bool { + err := k8sClient.Get(ctx, objKey, obj) + return err == nil + }, timeout, interval).Should(BeTrue(), "%T %s should exist", obj, objKey.Name) + } + + // Checks that any Kubernetes object to be not found. + expectResourceIsDeleted := func(ctx context.Context, obj client.Object, objKey types.NamespacedName) { + Consistently(func() bool { + err := k8sClient.Get(ctx, objKey, obj) + return apierr.IsNotFound(err) + }, time.Second*10, interval).Should(BeTrue(), "%T %s should not be created", obj, objKey.Name) + } + + // Wait for any Kubernetes object to be not found. 
+ expectResourceToBeDeleted := func(ctx context.Context, obj client.Object, objKey types.NamespacedName) { + Eventually(func() bool { + err := k8sClient.Get(ctx, objKey, obj) + return apierr.IsNotFound(err) + }, timeout, interval).Should(BeTrue(), "%T %s should be deleted", obj, objKey.Name) + } + + // Wait for a specific condition on an InferenceGraph to reach the desired status + expectIGConditionStatus := func(ctx context.Context, serviceKey types.NamespacedName, conditionType apis.ConditionType, expectedStatus corev1.ConditionStatus) { + message := fmt.Sprintf("The '%s' condition for InferenceGraph '%s' should be '%s'", + conditionType, serviceKey.Name, expectedStatus) + + actualIg := &v1alpha1.InferenceGraph{} + Eventually(func() bool { + err := k8sClient.Get(ctx, serviceKey, actualIg) + if err == nil { + cond := actualIg.Status.GetCondition(conditionType) + if cond != nil && cond.Status == expectedStatus { + return true + } + } + return false + }, timeout, interval).Should(BeTrue(), message) + } + + Describe("in Serverless mode", func() { + // --- Default values --- + defaultIG := func(serviceKey types.NamespacedName) *v1alpha1.InferenceGraph { + ig := &v1alpha1.InferenceGraph{ + ObjectMeta: metav1.ObjectMeta{ + Name: serviceKey.Name, + Namespace: serviceKey.Namespace, + Annotations: map[string]string{ + "serving.kserve.io/deploymentMode": string(constants.Serverless), + }, + }, + Spec: v1alpha1.InferenceGraphSpec{ + Nodes: map[string]v1alpha1.InferenceRouter{ + v1alpha1.GraphRootNodeName: { + RouterType: v1alpha1.Sequence, + Steps: []v1alpha1.InferenceStep{ + { + InferenceTarget: v1alpha1.InferenceTarget{ + ServiceURL: "http://someservice.exmaple.com", + }, + }, + }, + }, + }, + }, + } + return ig + } + + It("Should keep the knative service when the annotation is set to false", func() { + ctx, cancel := context.WithCancel(context.Background()) + DeferCleanup(cancel) + + // Config map + configMap := createIGConfigMap() + 
Expect(k8sClient.Create(context.TODO(), configMap)).NotTo(HaveOccurred()) + defer k8sClient.Delete(context.TODO(), configMap) + + // Define InferenceGraph + serviceNamespace := "default" + graphName := "stop-false-ig" + graphExpectedRequest := reconcile.Request{NamespacedName: types.NamespacedName{Name: graphName, Namespace: serviceNamespace}} + graphServiceKey := graphExpectedRequest.NamespacedName + ig := defaultIG(graphServiceKey) + ig.Annotations[constants.StopAnnotationKey] = "false" + Expect(k8sClient.Create(ctx, ig)).Should(Succeed()) + defer k8sClient.Delete(ctx, ig) + + // Check the inference graph + expectResourceToExist(context.Background(), &knservingv1.Service{}, graphServiceKey) + expectIGToExist(context.Background(), graphServiceKey) + + expectIGConditionStatus(ctx, graphServiceKey, v1beta1.Stopped, corev1.ConditionFalse) + }) + + It("Should not create the knative service when the annotation is set to true", func() { + ctx, cancel := context.WithCancel(context.Background()) + DeferCleanup(cancel) + + configMap := createIGConfigMap() + Expect(k8sClient.Create(ctx, configMap)).NotTo(HaveOccurred()) + defer k8sClient.Delete(ctx, configMap) + + graphName := "stop-true-ig" + serviceNamespace := "default" + expectedRequest := reconcile.Request{NamespacedName: types.NamespacedName{Name: graphName, Namespace: serviceNamespace}} + graphServiceKey := expectedRequest.NamespacedName + ig := defaultIG(graphServiceKey) + ig.Annotations[constants.StopAnnotationKey] = "true" + Expect(k8sClient.Create(context.Background(), ig)).Should(Succeed()) + defer k8sClient.Delete(context.Background(), ig) + + // Check that the knative service was not created + expectResourceIsDeleted(context.Background(), &knservingv1.Service{}, graphServiceKey) + + // Check the inference graph + expectIGToExist(context.Background(), graphServiceKey) + expectIGConditionStatus(ctx, graphServiceKey, v1beta1.Stopped, corev1.ConditionTrue) + }) + + It("Should delete the knative service when the 
annotation is updated to true on an existing IG", func() { + ctx, cancel := context.WithCancel(context.Background()) + DeferCleanup(cancel) + + // Config map + configMap := createIGConfigMap() + Expect(k8sClient.Create(context.TODO(), configMap)).NotTo(HaveOccurred()) + defer k8sClient.Delete(context.TODO(), configMap) + + // Define InferenceGraph + serviceNamespace := "default" + graphName := "stop-update-true-ig" + graphExpectedRequest := reconcile.Request{NamespacedName: types.NamespacedName{Name: graphName, Namespace: serviceNamespace}} + graphServiceKey := graphExpectedRequest.NamespacedName + ig := defaultIG(graphServiceKey) + ig.Annotations[constants.StopAnnotationKey] = "false" + Expect(k8sClient.Create(ctx, ig)).Should(Succeed()) + defer k8sClient.Delete(ctx, ig) + + // Check the inference graph + expectResourceToExist(context.Background(), &knservingv1.Service{}, graphServiceKey) + expectIGToExist(context.Background(), graphServiceKey) + + expectIGConditionStatus(ctx, graphServiceKey, v1beta1.Stopped, corev1.ConditionFalse) + + // Stop the inference graph + actualIG := expectIGToExist(ctx, graphServiceKey) + updatedIG := actualIG.DeepCopy() + updatedIG.Annotations[constants.StopAnnotationKey] = "true" + Expect(k8sClient.Update(ctx, updatedIG)).NotTo(HaveOccurred()) + + // Check that the knative service was deleted + expectResourceToBeDeleted(context.Background(), &knservingv1.Service{}, graphServiceKey) + + // Check the inference graph + expectIGToExist(context.Background(), graphServiceKey) + expectIGConditionStatus(ctx, graphServiceKey, v1beta1.Stopped, corev1.ConditionTrue) + }) + + It("Should create the knative service when the annotation is updated to false on an existing IG", func() { + ctx, cancel := context.WithCancel(context.Background()) + DeferCleanup(cancel) + + configMap := createIGConfigMap() + Expect(k8sClient.Create(ctx, configMap)).NotTo(HaveOccurred()) + defer k8sClient.Delete(ctx, configMap) + + graphName := "stop-update-false-ig" + 
serviceNamespace := "default" + expectedRequest := reconcile.Request{NamespacedName: types.NamespacedName{Name: graphName, Namespace: serviceNamespace}} + graphServiceKey := expectedRequest.NamespacedName + ig := defaultIG(graphServiceKey) + ig.Annotations[constants.StopAnnotationKey] = "true" + Expect(k8sClient.Create(context.Background(), ig)).Should(Succeed()) + defer k8sClient.Delete(context.Background(), ig) + + // Check that the knative service was not created + expectResourceIsDeleted(context.Background(), &knservingv1.Service{}, graphServiceKey) + + // Check the inference graph + expectIGToExist(context.Background(), graphServiceKey) + expectIGConditionStatus(ctx, graphServiceKey, v1beta1.Stopped, corev1.ConditionTrue) + + // Resume the inference graph + actualIG := expectIGToExist(ctx, graphServiceKey) + updatedIG := actualIG.DeepCopy() + updatedIG.Annotations[constants.StopAnnotationKey] = "false" + Expect(k8sClient.Update(ctx, updatedIG)).NotTo(HaveOccurred()) + + // Check the inference graph + expectResourceToExist(context.Background(), &knservingv1.Service{}, graphServiceKey) + expectIGToExist(context.Background(), graphServiceKey) + + expectIGConditionStatus(ctx, graphServiceKey, v1beta1.Stopped, corev1.ConditionFalse) + }) + }) + }) + Context("When creating an IG with tolerations in the spec", func() { It("Should propagate to underlying pod", func() { configMap := &corev1.ConfigMap{ diff --git a/pkg/controller/v1alpha1/inferencegraph/knative_reconciler.go b/pkg/controller/v1alpha1/inferencegraph/knative_reconciler.go index 8db052ad46f..13490ce875c 100644 --- a/pkg/controller/v1alpha1/inferencegraph/knative_reconciler.go +++ b/pkg/controller/v1alpha1/inferencegraph/knative_reconciler.go @@ -92,12 +92,28 @@ func (r *GraphKnativeServiceReconciler) Reconcile(ctx context.Context) (*knservi desired := r.Service existing := &knservingv1.Service{} + forceStopRuntime := false + if val, exist := desired.Spec.Template.Annotations[constants.StopAnnotationKey]; 
exist { + forceStopRuntime = strings.EqualFold(val, "true") + } + err := retry.RetryOnConflict(retry.DefaultBackoff, func() error { log.Info("Updating inference graph knative service", "namespace", desired.Namespace, "name", desired.Name) if err := r.client.Get(ctx, types.NamespacedName{Name: desired.Name, Namespace: desired.Namespace}, existing); err != nil { return err } + if forceStopRuntime { + log.Info("Deleting inference graph knative service", "namespace", existing.Namespace, "name", existing.Name) + if existing.GetDeletionTimestamp() == nil { // check if the ksvc was already deleted + err := r.client.Delete(ctx, existing) + if err != nil { + return err + } + } + return nil + } + // Set ResourceVersion which is required for update operation. desired.ResourceVersion = existing.ResourceVersion // Add immutable annotations to avoid validation error during dry-run update. @@ -121,9 +137,13 @@ func (r *GraphKnativeServiceReconciler) Reconcile(ctx context.Context) (*knservi return r.client.Update(ctx, existing) }) if err != nil { + // Create service if it does not exist if apierr.IsNotFound(err) { - log.Info("Creating inference graph knative service", "namespace", desired.Namespace, "name", desired.Name) - return &desired.Status, r.client.Create(ctx, desired) + if !forceStopRuntime { + log.Info("Creating inference graph knative service", "namespace", desired.Namespace, "name", desired.Name) + return &desired.Status, r.client.Create(ctx, desired) + } + return &desired.Status, nil } return &existing.Status, errors.Wrapf(err, "fails to reconcile inference graph knative service") } From 766e34507f819d44edd224dfcc9d15a603d6fcac Mon Sep 17 00:00:00 2001 From: Hare Hutaki <132063819+HutakiHare@users.noreply.github.com> Date: Wed, 23 Jul 2025 16:40:41 +0800 Subject: [PATCH 5/9] disallow `name` field in standard predictor (#4535) Signed-off-by: HutakiHare --- .../v1beta1/inference_service_validation.go | 38 +++++++++++++++++++ .../inference_service_validation_test.go | 26 
+++++++++++++ 2 files changed, 64 insertions(+) diff --git a/pkg/apis/serving/v1beta1/inference_service_validation.go b/pkg/apis/serving/v1beta1/inference_service_validation.go index 7f5dfd5d0f5..f32c07fbfb2 100644 --- a/pkg/apis/serving/v1beta1/inference_service_validation.go +++ b/pkg/apis/serving/v1beta1/inference_service_validation.go @@ -125,6 +125,10 @@ func validateInferenceService(isvc *InferenceService) (admission.Warnings, error return allWarnings, err } + if err := validatePredictor(isvc); err != nil { + return allWarnings, err + } + for _, component := range []Component{ &isvc.Spec.Predictor, isvc.Spec.Transformer, @@ -146,6 +150,40 @@ func validateInferenceService(isvc *InferenceService) (admission.Warnings, error return allWarnings, nil } +func validatePredictor(isvc *InferenceService) error { + predictor := isvc.Spec.Predictor + + // log predictor + validatorLogger.Info("Incoming predictor struct", "predictor", predictor) + + // in most of the case, standard predictors will all be packed into `predictor.model`, and decide the backend process through `modelFormat.name`` + switch { + case predictor.SKLearn != nil && predictor.SKLearn.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + case predictor.XGBoost != nil && predictor.XGBoost.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + case predictor.Tensorflow != nil && predictor.Tensorflow.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + case predictor.PyTorch != nil && predictor.PyTorch.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + case predictor.Triton != nil && predictor.Triton.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + case predictor.ONNX != nil && predictor.ONNX.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + case predictor.HuggingFace != nil 
&& predictor.HuggingFace.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + case predictor.PMML != nil && predictor.PMML.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + case predictor.LightGBM != nil && predictor.LightGBM.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + case predictor.Paddle != nil && predictor.Paddle.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + case predictor.Model != nil && predictor.Model.Name != "": + return errors.New("the 'name' field is not allowed in standard predictor") + } + return nil +} + // validateMultiNodeVariables validates when there is workerSpec set in isvc func validateMultiNodeVariables(isvc *InferenceService) error { if isvc.Spec.Predictor.WorkerSpec != nil { diff --git a/pkg/apis/serving/v1beta1/inference_service_validation_test.go b/pkg/apis/serving/v1beta1/inference_service_validation_test.go index 85c7e5e9a74..5e81fc827e3 100644 --- a/pkg/apis/serving/v1beta1/inference_service_validation_test.go +++ b/pkg/apis/serving/v1beta1/inference_service_validation_test.go @@ -18,6 +18,7 @@ package v1beta1 import ( "fmt" + "strings" "testing" "github.com/kserve/kserve/pkg/constants" @@ -32,6 +33,31 @@ import ( "k8s.io/utils/ptr" ) +func TestInvalidNameInSKLearnPredictor(t *testing.T) { + isvc := InferenceService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-isvc", + }, + Spec: InferenceServiceSpec{ + Predictor: PredictorSpec{ + SKLearn: &SKLearnSpec{ + PredictorExtensionSpec: PredictorExtensionSpec{ + Container: corev1.Container{ + Name: "invalid-name", + Image: "dummy-image", + }, + StorageURI: proto.String("gs://kfserving-examples/models/sklearn/1.0/model"), + }, + }, + }, + }, + } + err := validatePredictor(&isvc) + if err == nil || !strings.Contains(err.Error(), "not allowed") { + t.Errorf("Expected error for name field in SKLearn predictor, got: %v", err) + } 
+} + func makeTestInferenceService() InferenceService { inferenceservice := InferenceService{ ObjectMeta: metav1.ObjectMeta{ From c579dd68e08afbd172bdff65a43ceed6099ca172 Mon Sep 17 00:00:00 2001 From: Sivanantham <90966311+sivanantha321@users.noreply.github.com> Date: Fri, 25 Jul 2025 08:55:02 +0530 Subject: [PATCH 6/9] fix: missing pytest-asyncio session scope marker in test_kserve_logger_cipn (#4607) Signed-off-by: Sivanantham Chinnaiyan --- test/e2e/logger/test_raw_logger.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/e2e/logger/test_raw_logger.py b/test/e2e/logger/test_raw_logger.py index e347aec5e46..0dcf497ebf7 100644 --- a/test/e2e/logger/test_raw_logger.py +++ b/test/e2e/logger/test_raw_logger.py @@ -66,6 +66,7 @@ async def test_kserve_logger(rest_v1_client, network_layer): await base_test(msg_dumper, service_name, predictor, rest_v1_client, network_layer) +@pytest.mark.asyncio(scope="session") @pytest.mark.rawcipn async def test_kserve_logger_cipn(rest_v1_client, network_layer): msg_dumper = "message-dumper-raw-cipn" From 5c0493da0c026731f2e07eac846eb2c45d3b884c Mon Sep 17 00:00:00 2001 From: Vedant Mahabaleshwarkar Date: Fri, 25 Jul 2025 13:13:43 -0400 Subject: [PATCH 7/9] update CRDs for LWS based multi node support (#4596) Signed-off-by: Vedant Mahabaleshwarkar Co-authored-by: Bartosz Majsak Co-authored-by: Pierangelo Di Pilato --- ....kserve.io_llminferenceserviceconfigs.yaml | 32 ++++- ...erving.kserve.io_llminferenceservices.yaml | 32 ++++- ....kserve.io_llminferenceserviceconfigs.yaml | 32 ++++- ...erving.kserve.io_llminferenceservices.yaml | 32 ++++- .../llm_inference_service_defaults.go | 32 +++++ .../llm_inference_service_lifecycle.go | 133 ++++++++++++++++++ .../v1alpha1/llm_inference_service_types.go | 19 ++- .../llm_inference_service_types_func.go | 75 ++++++++++ pkg/apis/serving/v1alpha1/v1alpha1.go | 4 + .../serving/v1alpha1/zz_generated.deepcopy.go | 19 ++- .../v1alpha1/llmisvc/config_merge_test.go | 48 +++---- 11 files changed, 
405 insertions(+), 53 deletions(-) create mode 100644 pkg/apis/serving/v1alpha1/llm_inference_service_defaults.go create mode 100644 pkg/apis/serving/v1alpha1/llm_inference_service_lifecycle.go create mode 100644 pkg/apis/serving/v1alpha1/llm_inference_service_types_func.go diff --git a/charts/llmisvc-crd/templates/serving.kserve.io_llminferenceserviceconfigs.yaml b/charts/llmisvc-crd/templates/serving.kserve.io_llminferenceserviceconfigs.yaml index 4122b60eced..4598e42d17b 100644 --- a/charts/llmisvc-crd/templates/serving.kserve.io_llminferenceserviceconfigs.yaml +++ b/charts/llmisvc-crd/templates/serving.kserve.io_llminferenceserviceconfigs.yaml @@ -85,22 +85,44 @@ spec: type: object parallelism: properties: + data: + format: int32 + type: integer + dataLocal: + format: int32 + type: integer + dataRPCPort: + format: int32 + type: integer + expert: + type: boolean pipeline: - format: int64 + format: int32 type: integer tensor: - format: int64 + format: int32 type: integer type: object prefill: properties: parallelism: properties: + data: + format: int32 + type: integer + dataLocal: + format: int32 + type: integer + dataRPCPort: + format: int32 + type: integer + expert: + type: boolean pipeline: - format: int64 + format: int32 type: integer tensor: - format: int64 + format: int32 type: integer type: object replicas: @@ -19538,8 +19560,6 @@ spec: required: - containers type: object - required: - - model type: object type: object served: true diff --git a/charts/llmisvc-crd/templates/serving.kserve.io_llminferenceservices.yaml b/charts/llmisvc-crd/templates/serving.kserve.io_llminferenceservices.yaml index 9e7e4d68a74..3fa7b17282c 100644 --- a/charts/llmisvc-crd/templates/serving.kserve.io_llminferenceservices.yaml +++ b/charts/llmisvc-crd/templates/serving.kserve.io_llminferenceservices.yaml @@ -104,22 +104,44 @@ spec: type: object parallelism: properties: + data: + format: int32 + type: integer + dataLocal: + format: int32 + type: integer + dataRPCPort: + format: 
int32 + type: integer + expert: + type: boolean pipeline: - format: int64 + format: int32 type: integer tensor: - format: int64 + format: int32 type: integer type: object prefill: properties: parallelism: properties: + data: + format: int32 + type: integer + dataLocal: + format: int32 + type: integer + dataRPCPort: + format: int32 + type: integer + expert: + type: boolean pipeline: - format: int64 + format: int32 type: integer tensor: - format: int64 + format: int32 type: integer type: object replicas: @@ -19557,8 +19579,6 @@ spec: required: - containers type: object - required: - - model type: object status: properties: diff --git a/config/crd/full/serving.kserve.io_llminferenceserviceconfigs.yaml b/config/crd/full/serving.kserve.io_llminferenceserviceconfigs.yaml index 4122b60eced..4598e42d17b 100644 --- a/config/crd/full/serving.kserve.io_llminferenceserviceconfigs.yaml +++ b/config/crd/full/serving.kserve.io_llminferenceserviceconfigs.yaml @@ -85,22 +85,44 @@ spec: type: object parallelism: properties: + data: + format: int32 + type: integer + dataLocal: + format: int32 + type: integer + dataRPCPort: + format: int32 + type: integer + expert: + type: boolean pipeline: - format: int64 + format: int32 type: integer tensor: - format: int64 + format: int32 type: integer type: object prefill: properties: parallelism: properties: + data: + format: int32 + type: integer + dataLocal: + format: int32 + type: integer + dataRPCPort: + format: int32 + type: integer + expert: + type: boolean pipeline: - format: int64 + format: int32 type: integer tensor: - format: int64 + format: int32 type: integer type: object replicas: @@ -19538,8 +19560,6 @@ spec: required: - containers type: object - required: - - model type: object type: object served: true diff --git a/config/crd/full/serving.kserve.io_llminferenceservices.yaml b/config/crd/full/serving.kserve.io_llminferenceservices.yaml index 9e7e4d68a74..3fa7b17282c 100644 --- 
a/config/crd/full/serving.kserve.io_llminferenceservices.yaml +++ b/config/crd/full/serving.kserve.io_llminferenceservices.yaml @@ -104,22 +104,44 @@ spec: type: object parallelism: properties: + data: + format: int32 + type: integer + dataLocal: + format: int32 + type: integer + dataRPCPort: + format: int32 + type: integer + expert: + type: boolean pipeline: - format: int64 + format: int32 type: integer tensor: - format: int64 + format: int32 type: integer type: object prefill: properties: parallelism: properties: + data: + format: int32 + type: integer + dataLocal: + format: int32 + type: integer + dataRPCPort: + format: int32 + type: integer + expert: + type: boolean pipeline: - format: int64 + format: int32 type: integer tensor: - format: int64 + format: int32 type: integer type: object replicas: @@ -19557,8 +19579,6 @@ spec: required: - containers type: object - required: - - model type: object status: properties: diff --git a/pkg/apis/serving/v1alpha1/llm_inference_service_defaults.go b/pkg/apis/serving/v1alpha1/llm_inference_service_defaults.go new file mode 100644 index 00000000000..899c23945e2 --- /dev/null +++ b/pkg/apis/serving/v1alpha1/llm_inference_service_defaults.go @@ -0,0 +1,32 @@ +/* +Copyright 2025 The KServe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package v1alpha1 + +import ( + "context" + + "k8s.io/utils/ptr" + "knative.dev/pkg/apis" +) + +var _ apis.Defaultable = &LLMInferenceService{} + +func (in *LLMInferenceService) SetDefaults(_ context.Context) { + if in.Spec.Model.Name == nil || *in.Spec.Model.Name == "" { + in.Spec.Model.Name = ptr.To(in.GetName()) + } +} diff --git a/pkg/apis/serving/v1alpha1/llm_inference_service_lifecycle.go b/pkg/apis/serving/v1alpha1/llm_inference_service_lifecycle.go new file mode 100644 index 00000000000..f42453d4cbd --- /dev/null +++ b/pkg/apis/serving/v1alpha1/llm_inference_service_lifecycle.go @@ -0,0 +1,133 @@ +/* +Copyright 2025 The KServe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package v1alpha1 + +import ( + "knative.dev/pkg/apis" + duckv1 "knative.dev/pkg/apis/duck/v1" +) + +const ( + PresetsCombined apis.ConditionType = "PresetsCombined" + WorkloadReady apis.ConditionType = "WorkloadsReady" + RouterReady apis.ConditionType = "RouterReady" +) + +const ( + MainWorkloadReady apis.ConditionType = "MainWorkloadReady" + WorkerWorkloadReady apis.ConditionType = "WorkerWorkloadReady" + PrefillWorkloadReady apis.ConditionType = "PrefillWorkloadReady" + PrefillWorkerWorkloadReady apis.ConditionType = "PrefillWorkerWorkloadReady" +) + +const ( + SchedulerWorkloadReady apis.ConditionType = "SchedulerWorkloadReady" +) + +var llmInferenceServiceCondSet = apis.NewLivingConditionSet( + WorkloadReady, + RouterReady, +) + +func (in *LLMInferenceService) GetStatus() *duckv1.Status { + return &in.Status.Status +} + +func (in *LLMInferenceService) GetConditionSet() apis.ConditionSet { + return llmInferenceServiceCondSet +} + +func (in *LLMInferenceService) MarkWorkloadNotReady(reason, messageFormat string, messageA ...interface{}) { + in.GetConditionSet().Manage(in.GetStatus()).MarkFalse(WorkloadReady, reason, messageFormat, messageA...) +} + +func (in *LLMInferenceService) MarkMainWorkloadReady() { + in.GetConditionSet().Manage(in.GetStatus()).MarkTrue(MainWorkloadReady) +} + +func (in *LLMInferenceService) MarkMainWorkloadNotReady(reason, messageFormat string, messageA ...interface{}) { + in.GetConditionSet().Manage(in.GetStatus()).MarkFalse(MainWorkloadReady, reason, messageFormat, messageA...) +} + +func (in *LLMInferenceService) MarkWorkerWorkloadReady() { + in.GetConditionSet().Manage(in.GetStatus()).MarkTrue(WorkerWorkloadReady) +} + +func (in *LLMInferenceService) MarkWorkerWorkloadNotReady(reason, messageFormat string, messageA ...interface{}) { + in.GetConditionSet().Manage(in.GetStatus()).MarkFalse(WorkerWorkloadReady, reason, messageFormat, messageA...) 
+} + +func (in *LLMInferenceService) MarkPrefillWorkloadReady() { + in.GetConditionSet().Manage(in.GetStatus()).MarkTrue(PrefillWorkloadReady) +} + +func (in *LLMInferenceService) MarkPrefillWorkloadNotReady(reason, messageFormat string, messageA ...interface{}) { + in.GetConditionSet().Manage(in.GetStatus()).MarkFalse(PrefillWorkloadReady, reason, messageFormat, messageA...) +} + +func (in *LLMInferenceService) MarkPrefillWorkerWorkloadReady() { + in.GetConditionSet().Manage(in.GetStatus()).MarkTrue(PrefillWorkerWorkloadReady) +} + +func (in *LLMInferenceService) MarkPrefillWorkerWorkloadNotReady(reason, messageFormat string, messageA ...interface{}) { + in.GetConditionSet().Manage(in.GetStatus()).MarkFalse(PrefillWorkerWorkloadReady, reason, messageFormat, messageA...) +} + +func (in *LLMInferenceService) DetermineWorkloadReadiness() { + subConditions := []*apis.Condition{ + in.GetStatus().GetCondition(MainWorkloadReady), + in.GetStatus().GetCondition(WorkerWorkloadReady), + in.GetStatus().GetCondition(PrefillWorkloadReady), + in.GetStatus().GetCondition(PrefillWorkerWorkloadReady), + in.GetStatus().GetCondition(SchedulerWorkloadReady), + } + + for _, cond := range subConditions { + if cond == nil { + continue + } + if cond.IsFalse() { + in.GetConditionSet().Manage(in.GetStatus()).MarkFalse(WorkloadReady, cond.Reason, cond.Message) + return + } + } + in.GetConditionSet().Manage(in.GetStatus()).MarkTrue(WorkloadReady) +} + +func (in *LLMInferenceService) MarkRouterNotReady(reason, messageFormat string, messageA ...interface{}) { + in.GetConditionSet().Manage(in.GetStatus()).MarkFalse(RouterReady, reason, messageFormat, messageA...) 
+} + +func (in *LLMInferenceService) MarkRouterReady() { + in.GetConditionSet().Manage(in.GetStatus()).MarkTrue(RouterReady) +} + +func (in *LLMInferenceService) MarkPresetsCombinedReady() { + in.GetConditionSet().Manage(in.GetStatus()).MarkTrue(PresetsCombined) +} + +func (in *LLMInferenceService) MarkPresetsCombinedNotReady(reason, messageFormat string, messageA ...interface{}) { + in.GetConditionSet().Manage(in.GetStatus()).MarkFalse(PresetsCombined, reason, messageFormat, messageA...) +} + +func (in *LLMInferenceService) MarkSchedulerWorkloadReady() { + in.GetConditionSet().Manage(in.GetStatus()).MarkTrue(SchedulerWorkloadReady) +} + +func (in *LLMInferenceService) MarkSchedulerWorkloadNotReady(reason, messageFormat string, messageA ...interface{}) { + in.GetConditionSet().Manage(in.GetStatus()).MarkFalse(SchedulerWorkloadReady, reason, messageFormat, messageA...) +} diff --git a/pkg/apis/serving/v1alpha1/llm_inference_service_types.go b/pkg/apis/serving/v1alpha1/llm_inference_service_types.go index c9fb37f50f6..c359a061d5f 100644 --- a/pkg/apis/serving/v1alpha1/llm_inference_service_types.go +++ b/pkg/apis/serving/v1alpha1/llm_inference_service_types.go @@ -59,6 +59,8 @@ type LLMInferenceServiceConfig struct { // LLMInferenceServiceSpec defines the desired state of LLMInferenceService. type LLMInferenceServiceSpec struct { // Model specification, including its URI, potential LoRA adapters, and storage details. + // It's optional for `LLMInferenceServiceConfig` kind. + // +optional Model LLMModelSpec `json:"model"` // WorkloadSpec configurations for the primary inference deployment. @@ -245,11 +247,22 @@ type InferencePoolSpec struct { type ParallelismSpec struct { // Tensor parallelism size. // +optional - Tensor *int64 `json:"tensor,omitempty"` + Tensor *int32 `json:"tensor,omitempty"` // Pipeline parallelism size. // +optional - Pipeline *int64 `json:"pipeline,omitempty"` - // TODO more to be added ... 
+ Pipeline *int32 `json:"pipeline,omitempty"` + // Data parallelism size. + // +optional + Data *int32 `json:"data,omitempty"` + // DataLocal data local parallelism size. + // +optional + DataLocal *int32 `json:"dataLocal,omitempty"` + // DataRPCPort is the data parallelism RPC port. + // +optional + DataRPCPort *int32 `json:"dataRPCPort,omitempty"` + // Expert enables expert parallelism. + // +optional + Expert bool `json:"expert,omitempty"` } // LLMStorageSpec is a copy of the v1beta1.StorageSpec. It is duplicated here to avoid diff --git a/pkg/apis/serving/v1alpha1/llm_inference_service_types_func.go b/pkg/apis/serving/v1alpha1/llm_inference_service_types_func.go new file mode 100644 index 00000000000..968262bc8ed --- /dev/null +++ b/pkg/apis/serving/v1alpha1/llm_inference_service_types_func.go @@ -0,0 +1,75 @@ +/* +Copyright 2025 The KServe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package v1alpha1
+
+import (
+	"k8s.io/utils/ptr"
+)
+
+// IsManaged reports whether in is non-nil and equal to the zero-value
+// GatewayRoutesSpec, i.e. the section is present but has no field set.
+//
+// NOTE(review): presumably an empty section means the routes are managed
+// automatically - confirm against the callers.
+func (in *GatewayRoutesSpec) IsManaged() bool {
+	// Compare by value. The previous form, in == &GatewayRoutesSpec{}, compared in
+	// with the address of a freshly allocated literal and was therefore always false.
+	// NOTE(review): this needs GatewayRoutesSpec to be a comparable struct (no slice,
+	// map or func fields); its definition is not visible from here.
+	return in != nil && *in == (GatewayRoutesSpec{})
+}
+
+// HasRefs reports whether in is non-nil and has at least one entry in Refs.
+func (in *GatewaySpec) HasRefs() bool {
+	return in != nil && len(in.Refs) > 0
+}
+
+// HasRefs reports whether r is non-nil and has at least one entry in Refs.
+func (r *HTTPRouteSpec) HasRefs() bool {
+	return r != nil && len(r.Refs) > 0
+}
+
+// HasSpec reports whether r is non-nil and carries a non-nil Spec.
+func (r *HTTPRouteSpec) HasSpec() bool {
+	return r != nil && r.Spec != nil
+}
+
+// HasRef reports whether p is non-nil and has a Ref with a non-empty Name.
+func (p *InferencePoolSpec) HasRef() bool {
+	return p != nil && p.Ref != nil && p.Ref.Name != ""
+}
+
+// IsPipelineParallel reports whether Pipeline is set to a value greater than zero.
+// It is safe to call on a nil receiver.
+func (p *ParallelismSpec) IsPipelineParallel() bool {
+	if p == nil {
+		return false
+	}
+	return ptr.Deref(p.Pipeline, 0) > 0
+}
+
+// IsDataParallel reports whether Data or DataLocal is set to a value greater than
+// zero. It is safe to call on a nil receiver.
+func (p *ParallelismSpec) IsDataParallel() bool {
+	if p == nil {
+		return false
+	}
+	return ptr.Deref(p.Data, 0) > 0 || ptr.Deref(p.DataLocal, 0) > 0
+}
+
+// IsTensorParallel reports whether Tensor is set to a value greater than zero.
+// It is safe to call on a nil receiver.
+func (p *ParallelismSpec) IsTensorParallel() bool {
+	if p == nil {
+		return false
+	}
+	return ptr.Deref(p.Tensor, 0) > 0
+}
+
+// GetSize returns nil for a nil receiver. When IsDataParallel is true it returns
+// Data / DataLocal using integer division, treating unset or non-positive values
+// as 1. Otherwise, when IsPipelineParallel is true, it returns the Pipeline
+// pointer itself (not a copy). In every other case it returns nil.
+//
+// NOTE(review): the division truncates, so the result is 0 whenever Data is
+// smaller than DataLocal (for example Data unset and DataLocal 2) - confirm that
+// callers accept a zero size.
+func (p *ParallelismSpec) GetSize() *int32 {
+	if p == nil {
+		return nil
+	}
+	if p.IsDataParallel() {
+		return ptr.To(max(ptr.Deref(p.Data, 1), 1) / max(ptr.Deref(p.DataLocal, 1), 1))
+	}
+	if p.IsPipelineParallel() {
+		return p.Pipeline
+	}
+	return nil
+}
diff --git a/pkg/apis/serving/v1alpha1/v1alpha1.go b/pkg/apis/serving/v1alpha1/v1alpha1.go
index eeae0f49a08..515eaaa036f 100644
--- a/pkg/apis/serving/v1alpha1/v1alpha1.go
+++ b/pkg/apis/serving/v1alpha1/v1alpha1.go
@@ -40,6 +40,10 @@ var (
 	// SchemeBuilder is used to add go types to the GroupVersionKind scheme
 	SchemeBuilder = &scheme.Builder{GroupVersion: SchemeGroupVersion}
 
+	LLMInferenceServiceGVK = SchemeGroupVersion.WithKind("LLMInferenceService")
+
+	LLMInferenceServiceConfigGVK = SchemeGroupVersion.WithKind("LLMInferenceServiceConfig")
+
 	// AddToScheme is required by pkg/client/...
AddToScheme = SchemeBuilder.AddToScheme ) diff --git a/pkg/apis/serving/v1alpha1/zz_generated.deepcopy.go b/pkg/apis/serving/v1alpha1/zz_generated.deepcopy.go index de46cb10a5b..6aab3234fa3 100644 --- a/pkg/apis/serving/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/apis/serving/v1alpha1/zz_generated.deepcopy.go @@ -1162,12 +1162,27 @@ func (in *ParallelismSpec) DeepCopyInto(out *ParallelismSpec) { *out = *in if in.Tensor != nil { in, out := &in.Tensor, &out.Tensor - *out = new(int64) + *out = new(int32) **out = **in } if in.Pipeline != nil { in, out := &in.Pipeline, &out.Pipeline - *out = new(int64) + *out = new(int32) + **out = **in + } + if in.Data != nil { + in, out := &in.Data, &out.Data + *out = new(int32) + **out = **in + } + if in.DataLocal != nil { + in, out := &in.DataLocal, &out.DataLocal + *out = new(int32) + **out = **in + } + if in.DataRPCPort != nil { + in, out := &in.DataRPCPort, &out.DataRPCPort + *out = new(int32) **out = **in } } diff --git a/pkg/controller/v1alpha1/llmisvc/config_merge_test.go b/pkg/controller/v1alpha1/llmisvc/config_merge_test.go index 92e96e949ca..27b751d4f8a 100644 --- a/pkg/controller/v1alpha1/llmisvc/config_merge_test.go +++ b/pkg/controller/v1alpha1/llmisvc/config_merge_test.go @@ -287,7 +287,7 @@ func TestMergeSpecs(t *testing.T) { { WorkloadSpec: v1alpha1.WorkloadSpec{ Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](2), + Tensor: ptr.To[int32](2), }, }, }, @@ -295,7 +295,7 @@ func TestMergeSpecs(t *testing.T) { { WorkloadSpec: v1alpha1.WorkloadSpec{ Parallelism: &v1alpha1.ParallelismSpec{ - Pipeline: ptr.To[int64](4), + Pipeline: ptr.To[int32](4), }, }, }, @@ -304,8 +304,8 @@ func TestMergeSpecs(t *testing.T) { WorkloadSpec: v1alpha1.WorkloadSpec{ // Both parallelism values should be present Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](2), - Pipeline: ptr.To[int64](4), + Tensor: ptr.To[int32](2), + Pipeline: ptr.To[int32](4), }, }, }, @@ -379,8 +379,8 @@ func TestMergeSpecs(t *testing.T) { 
{ WorkloadSpec: v1alpha1.WorkloadSpec{ Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](1), - Pipeline: ptr.To[int64](1), + Tensor: ptr.To[int32](1), + Pipeline: ptr.To[int32](1), }, Worker: &corev1.PodSpec{ Containers: []corev1.Container{ @@ -425,8 +425,8 @@ func TestMergeSpecs(t *testing.T) { { WorkloadSpec: v1alpha1.WorkloadSpec{ Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](4), - Pipeline: ptr.To[int64](2), + Tensor: ptr.To[int32](4), + Pipeline: ptr.To[int32](2), }, Worker: &corev1.PodSpec{ Containers: []corev1.Container{ @@ -472,8 +472,8 @@ func TestMergeSpecs(t *testing.T) { }, WorkloadSpec: v1alpha1.WorkloadSpec{ Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](4), - Pipeline: ptr.To[int64](2), + Tensor: ptr.To[int32](4), + Pipeline: ptr.To[int32](2), }, Worker: &corev1.PodSpec{ Containers: []corev1.Container{ @@ -532,8 +532,8 @@ func TestMergeSpecs(t *testing.T) { { WorkloadSpec: v1alpha1.WorkloadSpec{ Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](1), - Pipeline: ptr.To[int64](1), + Tensor: ptr.To[int32](1), + Pipeline: ptr.To[int32](1), }, Worker: &corev1.PodSpec{ Containers: []corev1.Container{ @@ -564,8 +564,8 @@ func TestMergeSpecs(t *testing.T) { { WorkloadSpec: v1alpha1.WorkloadSpec{ Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](4), - Pipeline: ptr.To[int64](2), + Tensor: ptr.To[int32](4), + Pipeline: ptr.To[int32](2), }, Worker: &corev1.PodSpec{ Containers: []corev1.Container{ @@ -614,8 +614,8 @@ func TestMergeSpecs(t *testing.T) { }, WorkloadSpec: v1alpha1.WorkloadSpec{ Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](4), - Pipeline: ptr.To[int64](2), + Tensor: ptr.To[int32](4), + Pipeline: ptr.To[int32](2), }, Worker: &corev1.PodSpec{ Containers: []corev1.Container{ @@ -645,8 +645,8 @@ func TestMergeSpecs(t *testing.T) { { WorkloadSpec: v1alpha1.WorkloadSpec{ Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](1), - Pipeline: 
ptr.To[int64](1), + Tensor: ptr.To[int32](1), + Pipeline: ptr.To[int32](1), }, Worker: &corev1.PodSpec{ Containers: []corev1.Container{ @@ -707,8 +707,8 @@ func TestMergeSpecs(t *testing.T) { }, WorkloadSpec: v1alpha1.WorkloadSpec{ Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](1), - Pipeline: ptr.To[int64](1), + Tensor: ptr.To[int32](1), + Pipeline: ptr.To[int32](1), }, Worker: &corev1.PodSpec{ Containers: []corev1.Container{ @@ -964,7 +964,7 @@ func TestMergeSpecs(t *testing.T) { WorkloadSpec: v1alpha1.WorkloadSpec{ Replicas: ptr.To[int32](1), Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](2), + Tensor: ptr.To[int32](2), }, }, Router: &v1alpha1.RouterSpec{ @@ -992,7 +992,7 @@ func TestMergeSpecs(t *testing.T) { WorkloadSpec: v1alpha1.WorkloadSpec{ Replicas: ptr.To[int32](5), Parallelism: &v1alpha1.ParallelismSpec{ - Pipeline: ptr.To[int64](4), + Pipeline: ptr.To[int32](4), }, }, Router: &v1alpha1.RouterSpec{ @@ -1025,8 +1025,8 @@ func TestMergeSpecs(t *testing.T) { WorkloadSpec: v1alpha1.WorkloadSpec{ Replicas: ptr.To[int32](5), Parallelism: &v1alpha1.ParallelismSpec{ - Tensor: ptr.To[int64](2), // Base tensor preserved - Pipeline: ptr.To[int64](4), // Override pipeline + Tensor: ptr.To[int32](2), // Base tensor preserved + Pipeline: ptr.To[int32](4), // Override pipeline }, }, Router: &v1alpha1.RouterSpec{ From ed64a1e35d1959708f6b9da39aa8ce7d9fb1f477 Mon Sep 17 00:00:00 2001 From: Vedant Mahabaleshwarkar Date: Sat, 26 Jul 2025 21:45:55 -0400 Subject: [PATCH 8/9] Add LLM InferenceService base configurations (#4613) Signed-off-by: Vedant Mahabaleshwarkar Co-authored-by: Bartosz Majsak Co-authored-by: Pierangelo Di Pilato Co-authored-by: Sivanantham <90966311+sivanantha321@users.noreply.github.com> --- Makefile | 3 + .../templates/config-llm-decode-template.yaml | 131 ++++++ ...onfig-llm-decode-worker-data-parallel.yaml | 170 ++++++++ .../config-llm-prefill-template.yaml | 81 ++++ ...nfig-llm-prefill-worker-data-parallel.yaml | 
124 ++++++ .../templates/config-llm-router-route.yaml | 38 ++ .../templates/config-llm-scheduler.yaml | 89 ++++ .../templates/config-llm-template.yaml | 82 ++++ .../config-llm-worker-data-parallel.yaml | 123 ++++++ .../llmisvc/config-llm-decode-template.yaml | 131 ++++++ ...onfig-llm-decode-worker-data-parallel.yaml | 170 ++++++++ .../llmisvc/config-llm-prefill-template.yaml | 81 ++++ ...nfig-llm-prefill-worker-data-parallel.yaml | 124 ++++++ config/llmisvc/config-llm-router-route.yaml | 38 ++ config/llmisvc/config-llm-scheduler.yaml | 89 ++++ config/llmisvc/config-llm-template.yaml | 82 ++++ .../config-llm-worker-data-parallel.yaml | 123 ++++++ config/llmisvc/kustomization.yaml | 14 + .../llm_inference_service_defaults.go | 1 + .../v1alpha1/llmisvc/config_merge.go | 164 ++++++++ .../v1alpha1/llmisvc/config_presets_test.go | 382 ++++++++++++++++++ pkg/controller/v1alpha1/llmisvc/sample.go | 226 +++++++++++ pkg/controller/v1alpha1/llmisvc/utils.go | 32 ++ 23 files changed, 2498 insertions(+) create mode 100644 charts/llmisvc-resources/templates/config-llm-decode-template.yaml create mode 100644 charts/llmisvc-resources/templates/config-llm-decode-worker-data-parallel.yaml create mode 100644 charts/llmisvc-resources/templates/config-llm-prefill-template.yaml create mode 100644 charts/llmisvc-resources/templates/config-llm-prefill-worker-data-parallel.yaml create mode 100644 charts/llmisvc-resources/templates/config-llm-router-route.yaml create mode 100644 charts/llmisvc-resources/templates/config-llm-scheduler.yaml create mode 100644 charts/llmisvc-resources/templates/config-llm-template.yaml create mode 100644 charts/llmisvc-resources/templates/config-llm-worker-data-parallel.yaml create mode 100644 config/llmisvc/config-llm-decode-template.yaml create mode 100644 config/llmisvc/config-llm-decode-worker-data-parallel.yaml create mode 100644 config/llmisvc/config-llm-prefill-template.yaml create mode 100644 config/llmisvc/config-llm-prefill-worker-data-parallel.yaml 
create mode 100644 config/llmisvc/config-llm-router-route.yaml create mode 100644 config/llmisvc/config-llm-scheduler.yaml create mode 100644 config/llmisvc/config-llm-template.yaml create mode 100644 config/llmisvc/config-llm-worker-data-parallel.yaml create mode 100644 config/llmisvc/kustomization.yaml create mode 100644 pkg/controller/v1alpha1/llmisvc/config_presets_test.go create mode 100644 pkg/controller/v1alpha1/llmisvc/sample.go create mode 100644 pkg/controller/v1alpha1/llmisvc/utils.go diff --git a/Makefile b/Makefile index 8436a286af0..34d31c2a92e 100644 --- a/Makefile +++ b/Makefile @@ -96,6 +96,9 @@ manifests: controller-gen yq echo '{{- if .Values.kserve.localmodel.enabled }}'> charts/kserve-resources/templates/localmodelnode/role.yaml cat config/rbac/localmodelnode/role.yaml >> charts/kserve-resources/templates/localmodelnode/role.yaml echo '{{- end }}' >> charts/kserve-resources/templates/localmodelnode/role.yaml + # Copy the llmisvc templates + cp config/llmisvc/* charts/llmisvc-resources/templates/ + rm charts/llmisvc-resources/templates/kustomization.yaml @$(CONTROLLER_GEN) object:headerFile="hack/boilerplate.go.txt" paths=./pkg/apis/serving/v1alpha1 @$(CONTROLLER_GEN) object:headerFile="hack/boilerplate.go.txt" paths=./pkg/apis/serving/v1beta1 diff --git a/charts/llmisvc-resources/templates/config-llm-decode-template.yaml b/charts/llmisvc-resources/templates/config-llm-decode-template.yaml new file mode 100644 index 00000000000..07934abacf7 --- /dev/null +++ b/charts/llmisvc-resources/templates/config-llm-decode-template.yaml @@ -0,0 +1,131 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-decode-template +spec: + template: + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8001 + protocol: TCP + command: + - vllm + - serve + args: + - --served-model-name + - "{{ .Spec.Model.Name }}" + - --port + - "8001" + - 
--disable-log-requests + - --enable-ssl-refresh + - --ssl-certfile + - /etc/ssl/certs/tls.crt + - --ssl-keyfile + - /etc/ssl/certs/tls.key + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - MKNOD + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8001 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8001 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + initContainers: + - name: llm-d-routing-sidecar + imagePullPolicy: IfNotPresent + image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0 + restartPolicy: Always + ports: + - containerPort: 8000 + protocol: TCP + resources: { } + securityContext: { } + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + args: + - "--port=8000" + - "--vllm-port=8001" + - "--secure-proxy=true" + - "--cert-path=/etc/ssl/certs" + - "--decoder-use-tls=true" + - "--decoder-tls-insecure-skip-verify=true" + - "--prefiller-use-tls=true" + - "--prefiller-tls-insecure-skip-verify=true" + - "--enable-ssrf-protection=true" + volumeMounts: + - mountPath: /etc/ssl/certs + name: tls-certs + 
readOnly: true + env: + - name: INFERENCE_POOL_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/charts/llmisvc-resources/templates/config-llm-decode-worker-data-parallel.yaml b/charts/llmisvc-resources/templates/config-llm-decode-worker-data-parallel.yaml new file mode 100644 index 00000000000..bf5e9c8e197 --- /dev/null +++ b/charts/llmisvc-resources/templates/config-llm-decode-worker-data-parallel.yaml @@ -0,0 +1,170 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-decode-worker-data-parallel +spec: + worker: + initContainers: + - name: llm-d-routing-sidecar + imagePullPolicy: IfNotPresent + image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0 + restartPolicy: Always + ports: + - containerPort: 8000 + protocol: TCP + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + args: + - "--port=8000" + - "--vllm-port=8001" + - "--secure-proxy=true" + - "--cert-path=/etc/ssl/certs" + - "--decoder-use-tls=true" + - "--decoder-tls-insecure-skip-verify=true" + - "--prefiller-use-tls=true" + - "--prefiller-tls-insecure-skip-verify=true" + - "--enable-ssrf-protection=true" + volumeMounts: + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + env: + - name: INFERENCE_POOL_NAMESPACE + valueFrom: + fieldRef: + 
fieldPath: metadata.namespace + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8001 + protocol: TCP + stdin: true + tty: true + command: + - "/bin/sh" + - "-c" + args: + - |- + + START_RANK=$(( ${LWS_WORKER_INDEX:-0} * {{ or .Spec.Parallelism.DataLocal 1 }} )) + if [ "${LWS_WORKER_INDEX:-0}" -eq 0 ]; then + ################# + # Leader-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8001 \ + --api-server-count 4 \ + --disable-log-requests \ + {{- if .Spec.Parallelism.Expert -}}--enable-expert-parallel \{{- end }} + {{- if .Spec.Parallelism.Tensor -}}--tensor-parallel-size {{ .Spec.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Parallelism.DataRPCPort }}{{ .Spec.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + else + ################# + # Worker-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8001 \ + --disable-log-requests \ + {{- if .Spec.Parallelism.Expert }}--enable-expert-parallel \{{- end }} + {{- if .Spec.Parallelism.Tensor }}--tensor-parallel-size {{ .Spec.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Parallelism.DataRPCPort }}{{ .Spec.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --headless \ + --enable-ssl-refresh \ + --ssl-certfile \ + 
/etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + fi + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + capabilities: + add: + - "IPC_LOCK" + - "SYS_RAWIO" + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8001 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8001 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/charts/llmisvc-resources/templates/config-llm-prefill-template.yaml b/charts/llmisvc-resources/templates/config-llm-prefill-template.yaml new file mode 100644 index 00000000000..3b00fa0390f --- /dev/null +++ b/charts/llmisvc-resources/templates/config-llm-prefill-template.yaml @@ -0,0 +1,81 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-prefill-template +spec: + prefill: + template: + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8000 + protocol: TCP + command: + - vllm + - serve + - "{{ .Spec.Model.Name }}" + args: + - --served-model-name + - "{{ .Spec.Model.Name }}" + - --port + - "8000" + - 
--disable-log-requests + - --enable-ssl-refresh + - --ssl-certfile + - /etc/ssl/certs/tls.crt + - --ssl-keyfile + - /etc/ssl/certs/tls.key + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: File + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/charts/llmisvc-resources/templates/config-llm-prefill-worker-data-parallel.yaml b/charts/llmisvc-resources/templates/config-llm-prefill-worker-data-parallel.yaml new file mode 100644 index 00000000000..b9ef3b80e5e --- /dev/null +++ b/charts/llmisvc-resources/templates/config-llm-prefill-worker-data-parallel.yaml @@ -0,0 +1,124 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-prefill-worker-data-parallel +spec: + prefill: + worker: + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8000 + protocol: TCP + stdin: true + tty: true + command: + - "/bin/sh" + - "-c" + args: + - |- + + START_RANK=$(( ${LWS_WORKER_INDEX:-0} 
* {{ or .Spec.Prefill.Parallelism.DataLocal 1 }} )) + if [ "${LWS_WORKER_INDEX:-0}" -eq 0 ]; then + ################# + # Leader-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8000 \ + --api-server-count 4 \ + --disable-log-requests \ + {{- if .Spec.Prefill.Parallelism.Expert -}}--enable-expert-parallel \{{- end }} + {{- if .Spec.Prefill.Parallelism.Tensor -}}--tensor-parallel-size {{ .Spec.Prefill.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Prefill.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Prefill.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Prefill.Parallelism.DataRPCPort }}{{ .Spec.Prefill.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + else + ################# + # Worker-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8000 \ + --disable-log-requests \ + {{- if .Spec.Prefill.Parallelism.Expert -}}--enable-expert-parallel \{{- end }} + {{- if .Spec.Prefill.Parallelism.Tensor -}}--tensor-parallel-size {{ .Spec.Prefill.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Prefill.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Prefill.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Prefill.Parallelism.DataRPCPort }}{{ .Spec.Prefill.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --headless \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + fi + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: 
/models + securityContext: + allowPrivilegeEscalation: false + capabilities: + add: + - "IPC_LOCK" + - "SYS_RAWIO" + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/charts/llmisvc-resources/templates/config-llm-router-route.yaml b/charts/llmisvc-resources/templates/config-llm-router-route.yaml new file mode 100644 index 00000000000..bf4b5e8145c --- /dev/null +++ b/charts/llmisvc-resources/templates/config-llm-router-route.yaml @@ -0,0 +1,38 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-router-route +spec: + router: + route: + http: + spec: + parentRefs: + - group: gateway.networking.k8s.io + kind: Gateway + name: |- + {{ .GlobalConfig.IngressGatewayName }} + namespace: |- + {{ .GlobalConfig.IngressGatewayNamespace }} + rules: + - backendRefs: + - group: inference.networking.x-k8s.io + kind: InferencePool + name: |- + {{ ChildName .ObjectMeta.Name `-inference-pool` }} + port: 8000 + weight: 1 + matches: + - path: + type: PathPrefix + value: |- + /{{ .ObjectMeta.Namespace }}/{{ .ObjectMeta.Name }} + filters: + - type: 
URLRewrite + urlRewrite: + path: + type: ReplacePrefixMatch + replacePrefixMatch: / + timeouts: + backendRequest: 0s + request: 0s \ No newline at end of file diff --git a/charts/llmisvc-resources/templates/config-llm-scheduler.yaml b/charts/llmisvc-resources/templates/config-llm-scheduler.yaml new file mode 100644 index 00000000000..fc614fd7b0e --- /dev/null +++ b/charts/llmisvc-resources/templates/config-llm-scheduler.yaml @@ -0,0 +1,89 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-scheduler +spec: + router: + scheduler: + pool: + spec: + extensionRef: + failureMode: FailOpen + kind: Service + name: |- + {{ ChildName .ObjectMeta.Name `-epp-service` }} + selector: { } + targetPortNumber: 8000 + template: + containers: + - name: main + ports: + - containerPort: 9002 + name: grpc + protocol: TCP + - containerPort: 9003 + name: grpc-health + protocol: TCP + - containerPort: 9090 + name: metrics + protocol: TCP + image: ghcr.io/llm-d/llm-d-inference-scheduler:v0.2.0 + imagePullPolicy: IfNotPresent + livenessProbe: + failureThreshold: 3 + grpc: + port: 9003 + service: envoy.service.ext_proc.v3.ExternalProcessor + initialDelaySeconds: 5 + periodSeconds: 10 + successThreshold: 1 + timeoutSeconds: 1 + readinessProbe: + failureThreshold: 3 + grpc: + port: 9003 + service: envoy.service.ext_proc.v3.ExternalProcessor + initialDelaySeconds: 30 + periodSeconds: 10 + successThreshold: 1 + timeoutSeconds: 1 + args: + - --poolName + - "{{ ChildName .ObjectMeta.Name `-inference-pool` }}" + - --poolNamespace + - "{{ .ObjectMeta.Namespace }}" + - --zap-encoder + - json + - --grpcPort + - "9002" + - --grpcHealthPort + - "9003" + - --secureServing + - --certPath + - "/etc/ssl/certs" + resources: + requests: + cpu: 256m + memory: 500Mi + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + 
runAsNonRoot: true + capabilities: + drop: + - ALL + seccompProfile: + type: RuntimeDefault + volumeMounts: + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + volumes: + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" + dnsPolicy: ClusterFirst + restartPolicy: Always + terminationGracePeriodSeconds: 30 \ No newline at end of file diff --git a/charts/llmisvc-resources/templates/config-llm-template.yaml b/charts/llmisvc-resources/templates/config-llm-template.yaml new file mode 100644 index 00000000000..b731c15ce38 --- /dev/null +++ b/charts/llmisvc-resources/templates/config-llm-template.yaml @@ -0,0 +1,82 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-template +spec: + template: + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8000 + protocol: TCP + command: + - vllm + - serve + args: + - --served-model-name + - "{{ .Spec.Model.Name }}" + - --port + - "8000" + - --disable-log-requests + - --enable-ssl-refresh + - --ssl-certfile + - /etc/ssl/certs/tls.crt + - --ssl-keyfile + - /etc/ssl/certs/tls.key + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - MKNOD + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - 
mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/charts/llmisvc-resources/templates/config-llm-worker-data-parallel.yaml b/charts/llmisvc-resources/templates/config-llm-worker-data-parallel.yaml new file mode 100644 index 00000000000..f66e58cc1c2 --- /dev/null +++ b/charts/llmisvc-resources/templates/config-llm-worker-data-parallel.yaml @@ -0,0 +1,123 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-worker-data-parallel +spec: + worker: + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8000 + protocol: TCP + stdin: true + tty: true + command: + - "/bin/sh" + - "-c" + args: + - |- + + START_RANK=$(( ${LWS_WORKER_INDEX:-0} * {{ or .Spec.Parallelism.DataLocal 1 }} )) + if [ "${LWS_WORKER_INDEX:-0}" -eq 0 ]; then + ################# + # Leader-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8000 \ + --api-server-count 4 \ + --disable-log-requests \ + {{- if .Spec.Parallelism.Expert -}}--enable-expert-parallel \{{- end }} + {{- if .Spec.Parallelism.Tensor -}}--tensor-parallel-size {{ .Spec.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Parallelism.DataRPCPort }}{{ .Spec.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + 
--ssl-keyfile \ + /etc/ssl/certs/tls.key + else + ################# + # Worker-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8000 \ + --disable-log-requests \ + {{- if .Spec.Parallelism.Expert }}--enable-expert-parallel \{{- end }} + {{- if .Spec.Parallelism.Tensor }}--tensor-parallel-size {{ .Spec.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Parallelism.DataRPCPort }}{{ .Spec.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --headless \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + fi + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + capabilities: + add: + - "IPC_LOCK" + - "SYS_RAWIO" + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name 
`-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/config/llmisvc/config-llm-decode-template.yaml b/config/llmisvc/config-llm-decode-template.yaml new file mode 100644 index 00000000000..07934abacf7 --- /dev/null +++ b/config/llmisvc/config-llm-decode-template.yaml @@ -0,0 +1,131 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-decode-template +spec: + template: + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8001 + protocol: TCP + command: + - vllm + - serve + args: + - --served-model-name + - "{{ .Spec.Model.Name }}" + - --port + - "8001" + - --disable-log-requests + - --enable-ssl-refresh + - --ssl-certfile + - /etc/ssl/certs/tls.crt + - --ssl-keyfile + - /etc/ssl/certs/tls.key + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - MKNOD + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8001 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8001 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + initContainers: + - name: llm-d-routing-sidecar + imagePullPolicy: IfNotPresent + image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0 + restartPolicy: Always + ports: + - containerPort: 8000 + protocol: TCP + resources: { } + securityContext: { } + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: 
FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + args: + - "--port=8000" + - "--vllm-port=8001" + - "--secure-proxy=true" + - "--cert-path=/etc/ssl/certs" + - "--decoder-use-tls=true" + - "--decoder-tls-insecure-skip-verify=true" + - "--prefiller-use-tls=true" + - "--prefiller-tls-insecure-skip-verify=true" + - "--enable-ssrf-protection=true" + volumeMounts: + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + env: + - name: INFERENCE_POOL_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/config/llmisvc/config-llm-decode-worker-data-parallel.yaml b/config/llmisvc/config-llm-decode-worker-data-parallel.yaml new file mode 100644 index 00000000000..bf5e9c8e197 --- /dev/null +++ b/config/llmisvc/config-llm-decode-worker-data-parallel.yaml @@ -0,0 +1,170 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-decode-worker-data-parallel +spec: + worker: + initContainers: + - name: llm-d-routing-sidecar + imagePullPolicy: IfNotPresent + image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0 + restartPolicy: Always + ports: + - containerPort: 8000 + protocol: TCP + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + 
timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + args: + - "--port=8000" + - "--vllm-port=8001" + - "--secure-proxy=true" + - "--cert-path=/etc/ssl/certs" + - "--decoder-use-tls=true" + - "--decoder-tls-insecure-skip-verify=true" + - "--prefiller-use-tls=true" + - "--prefiller-tls-insecure-skip-verify=true" + - "--enable-ssrf-protection=true" + volumeMounts: + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + env: + - name: INFERENCE_POOL_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8001 + protocol: TCP + stdin: true + tty: true + command: + - "/bin/sh" + - "-c" + args: + - |- + + START_RANK=$(( ${LWS_WORKER_INDEX:-0} * {{ or .Spec.Parallelism.DataLocal 1 }} )) + if [ "${LWS_WORKER_INDEX:-0}" -eq 0 ]; then + ################# + # Leader-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8001 \ + --api-server-count 4 \ + --disable-log-requests \ + {{- if .Spec.Parallelism.Expert -}}--enable-expert-parallel \{{- end }} + {{- if .Spec.Parallelism.Tensor -}}--tensor-parallel-size {{ .Spec.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Parallelism.DataRPCPort }}{{ .Spec.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + else + ################# + # Worker-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8001 \ 
+ --disable-log-requests \ + {{- if .Spec.Parallelism.Expert }}--enable-expert-parallel \{{- end }} + {{- if .Spec.Parallelism.Tensor }}--tensor-parallel-size {{ .Spec.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Parallelism.DataRPCPort }}{{ .Spec.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --headless \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + fi + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + capabilities: + add: + - "IPC_LOCK" + - "SYS_RAWIO" + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8001 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8001 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/config/llmisvc/config-llm-prefill-template.yaml b/config/llmisvc/config-llm-prefill-template.yaml new file mode 100644 
index 00000000000..3b00fa0390f --- /dev/null +++ b/config/llmisvc/config-llm-prefill-template.yaml @@ -0,0 +1,81 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-prefill-template +spec: + prefill: + template: + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8000 + protocol: TCP + command: + - vllm + - serve + - "{{ .Spec.Model.Name }}" + args: + - --served-model-name + - "{{ .Spec.Model.Name }}" + - --port + - "8000" + - --disable-log-requests + - --enable-ssl-refresh + - --ssl-certfile + - /etc/ssl/certs/tls.crt + - --ssl-keyfile + - /etc/ssl/certs/tls.key + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: File + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/config/llmisvc/config-llm-prefill-worker-data-parallel.yaml b/config/llmisvc/config-llm-prefill-worker-data-parallel.yaml new file mode 100644 index 00000000000..b9ef3b80e5e --- /dev/null +++ 
b/config/llmisvc/config-llm-prefill-worker-data-parallel.yaml @@ -0,0 +1,124 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-prefill-worker-data-parallel +spec: + prefill: + worker: + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8000 + protocol: TCP + stdin: true + tty: true + command: + - "/bin/sh" + - "-c" + args: + - |- + + START_RANK=$(( ${LWS_WORKER_INDEX:-0} * {{ or .Spec.Prefill.Parallelism.DataLocal 1 }} )) + if [ "${LWS_WORKER_INDEX:-0}" -eq 0 ]; then + ################# + # Leader-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8000 \ + --api-server-count 4 \ + --disable-log-requests \ + {{- if .Spec.Prefill.Parallelism.Expert -}}--enable-expert-parallel \{{- end }} + {{- if .Spec.Prefill.Parallelism.Tensor -}}--tensor-parallel-size {{ .Spec.Prefill.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Prefill.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Prefill.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Prefill.Parallelism.DataRPCPort }}{{ .Spec.Prefill.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + else + ################# + # Worker-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8000 \ + --disable-log-requests \ + {{- if .Spec.Prefill.Parallelism.Expert -}}--enable-expert-parallel \{{- end }} + {{- if .Spec.Prefill.Parallelism.Tensor -}}--tensor-parallel-size {{ .Spec.Prefill.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Prefill.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Prefill.Parallelism.DataLocal 1 }} 
\ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Prefill.Parallelism.DataRPCPort }}{{ .Spec.Prefill.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --headless \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + fi + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + capabilities: + add: + - "IPC_LOCK" + - "SYS_RAWIO" + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/config/llmisvc/config-llm-router-route.yaml b/config/llmisvc/config-llm-router-route.yaml new file mode 100644 index 00000000000..bf4b5e8145c --- /dev/null +++ b/config/llmisvc/config-llm-router-route.yaml @@ -0,0 +1,38 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-router-route +spec: + router: + route: + http: + spec: + parentRefs: + - group: 
gateway.networking.k8s.io + kind: Gateway + name: |- + {{ .GlobalConfig.IngressGatewayName }} + namespace: |- + {{ .GlobalConfig.IngressGatewayNamespace }} + rules: + - backendRefs: + - group: inference.networking.x-k8s.io + kind: InferencePool + name: |- + {{ ChildName .ObjectMeta.Name `-inference-pool` }} + port: 8000 + weight: 1 + matches: + - path: + type: PathPrefix + value: |- + /{{ .ObjectMeta.Namespace }}/{{ .ObjectMeta.Name }} + filters: + - type: URLRewrite + urlRewrite: + path: + type: ReplacePrefixMatch + replacePrefixMatch: / + timeouts: + backendRequest: 0s + request: 0s \ No newline at end of file diff --git a/config/llmisvc/config-llm-scheduler.yaml b/config/llmisvc/config-llm-scheduler.yaml new file mode 100644 index 00000000000..fc614fd7b0e --- /dev/null +++ b/config/llmisvc/config-llm-scheduler.yaml @@ -0,0 +1,89 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-scheduler +spec: + router: + scheduler: + pool: + spec: + extensionRef: + failureMode: FailOpen + kind: Service + name: |- + {{ ChildName .ObjectMeta.Name `-epp-service` }} + selector: { } + targetPortNumber: 8000 + template: + containers: + - name: main + ports: + - containerPort: 9002 + name: grpc + protocol: TCP + - containerPort: 9003 + name: grpc-health + protocol: TCP + - containerPort: 9090 + name: metrics + protocol: TCP + image: ghcr.io/llm-d/llm-d-inference-scheduler:v0.2.0 + imagePullPolicy: IfNotPresent + livenessProbe: + failureThreshold: 3 + grpc: + port: 9003 + service: envoy.service.ext_proc.v3.ExternalProcessor + initialDelaySeconds: 5 + periodSeconds: 10 + successThreshold: 1 + timeoutSeconds: 1 + readinessProbe: + failureThreshold: 3 + grpc: + port: 9003 + service: envoy.service.ext_proc.v3.ExternalProcessor + initialDelaySeconds: 30 + periodSeconds: 10 + successThreshold: 1 + timeoutSeconds: 1 + args: + - --poolName + - "{{ ChildName .ObjectMeta.Name `-inference-pool` }}" + - --poolNamespace + - "{{ 
.ObjectMeta.Namespace }}" + - --zap-encoder + - json + - --grpcPort + - "9002" + - --grpcHealthPort + - "9003" + - --secureServing + - --certPath + - "/etc/ssl/certs" + resources: + requests: + cpu: 256m + memory: 500Mi + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + runAsNonRoot: true + capabilities: + drop: + - ALL + seccompProfile: + type: RuntimeDefault + volumeMounts: + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + volumes: + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" + dnsPolicy: ClusterFirst + restartPolicy: Always + terminationGracePeriodSeconds: 30 \ No newline at end of file diff --git a/config/llmisvc/config-llm-template.yaml b/config/llmisvc/config-llm-template.yaml new file mode 100644 index 00000000000..b731c15ce38 --- /dev/null +++ b/config/llmisvc/config-llm-template.yaml @@ -0,0 +1,82 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-template +spec: + template: + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8000 + protocol: TCP + command: + - vllm + - serve + args: + - --served-model-name + - "{{ .Spec.Model.Name }}" + - --port + - "8000" + - --disable-log-requests + - --enable-ssl-refresh + - --ssl-certfile + - /etc/ssl/certs/tls.crt + - --ssl-keyfile + - /etc/ssl/certs/tls.key + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: + - MKNOD + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + 
timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/config/llmisvc/config-llm-worker-data-parallel.yaml b/config/llmisvc/config-llm-worker-data-parallel.yaml new file mode 100644 index 00000000000..f66e58cc1c2 --- /dev/null +++ b/config/llmisvc/config-llm-worker-data-parallel.yaml @@ -0,0 +1,123 @@ +apiVersion: serving.kserve.io/v1alpha1 +kind: LLMInferenceServiceConfig +metadata: + name: kserve-config-llm-worker-data-parallel +spec: + worker: + containers: + - image: ghcr.io/llm-d/llm-d:v0.2.0 + imagePullPolicy: IfNotPresent + name: main + ports: + - containerPort: 8000 + protocol: TCP + stdin: true + tty: true + command: + - "/bin/sh" + - "-c" + args: + - |- + + START_RANK=$(( ${LWS_WORKER_INDEX:-0} * {{ or .Spec.Parallelism.DataLocal 1 }} )) + if [ "${LWS_WORKER_INDEX:-0}" -eq 0 ]; then + ################# + # Leader-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8000 \ + --api-server-count 4 \ + --disable-log-requests \ + {{- if .Spec.Parallelism.Expert -}}--enable-expert-parallel \{{- end }} + {{- if .Spec.Parallelism.Tensor -}}--tensor-parallel-size {{ .Spec.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) 
\ + --data-parallel-rpc-port {{ if .Spec.Parallelism.DataRPCPort }}{{ .Spec.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + else + ################# + # Worker-only launch + ################# + vllm serve \ + {{ .Spec.Model.Name }} \ + --port 8000 \ + --disable-log-requests \ + {{- if .Spec.Parallelism.Expert }}--enable-expert-parallel \{{- end }} + {{- if .Spec.Parallelism.Tensor }}--tensor-parallel-size {{ .Spec.Parallelism.Tensor }} \{{- end }} + --data-parallel-size {{ or .Spec.Parallelism.Data 1 }} \ + --data-parallel-size-local {{ or .Spec.Parallelism.DataLocal 1 }} \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port {{ if .Spec.Parallelism.DataRPCPort }}{{ .Spec.Parallelism.DataRPCPort }}{{ else }}5555{{- end }} \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --headless \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key + fi + env: + - name: HOME + value: /home + - name: VLLM_LOGGING_LEVEL + value: INFO + - name: HF_HUB_CACHE + value: /models + securityContext: + allowPrivilegeEscalation: false + capabilities: + add: + - "IPC_LOCK" + - "SYS_RAWIO" + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: FallbackToLogsOnError + livenessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 120 + periodSeconds: 10 + timeoutSeconds: 10 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /health + port: 8000 + scheme: HTTPS + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 60 + volumeMounts: + - mountPath: /home + name: home + - mountPath: /dev/shm + name: dshm + - mountPath: /models + name: model-cache + - mountPath: /etc/ssl/certs + name: tls-certs + readOnly: true + 
terminationGracePeriodSeconds: 30 + volumes: + - emptyDir: { } + name: home + - emptyDir: + medium: Memory + sizeLimit: 1Gi + name: dshm + - emptyDir: { } + name: model-cache + - name: tls-certs + secret: + secretName: "{{ ChildName .ObjectMeta.Name `-kserve-self-signed-certs` }}" \ No newline at end of file diff --git a/config/llmisvc/kustomization.yaml b/config/llmisvc/kustomization.yaml new file mode 100644 index 00000000000..a29a5458d69 --- /dev/null +++ b/config/llmisvc/kustomization.yaml @@ -0,0 +1,14 @@ +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization + +namespace: kserve + +resources: + - config-llm-decode-template.yaml + - config-llm-decode-worker-data-parallel.yaml + - config-llm-prefill-template.yaml + - config-llm-prefill-worker-data-parallel.yaml + - config-llm-router-route.yaml + - config-llm-scheduler.yaml + - config-llm-template.yaml + - config-llm-worker-data-parallel.yaml \ No newline at end of file diff --git a/pkg/apis/serving/v1alpha1/llm_inference_service_defaults.go b/pkg/apis/serving/v1alpha1/llm_inference_service_defaults.go index 899c23945e2..7eab13f8612 100644 --- a/pkg/apis/serving/v1alpha1/llm_inference_service_defaults.go +++ b/pkg/apis/serving/v1alpha1/llm_inference_service_defaults.go @@ -1,4 +1,5 @@ /* + Copyright 2025 The KServe Authors. 
Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/pkg/controller/v1alpha1/llmisvc/config_merge.go b/pkg/controller/v1alpha1/llmisvc/config_merge.go index 57ad1d71ac7..27e785438a4 100644 --- a/pkg/controller/v1alpha1/llmisvc/config_merge.go +++ b/pkg/controller/v1alpha1/llmisvc/config_merge.go @@ -18,16 +18,161 @@ package llmisvc import ( "bytes" + "context" "encoding/json" "fmt" "text/template" + "github.com/kserve/kserve/pkg/constants" + + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/strategicpatch" "knative.dev/pkg/kmeta" + "sigs.k8s.io/controller-runtime/pkg/client" + igwapi "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" ) +const ( + configPrefix = "kserve-" + configTemplateName = configPrefix + "config-llm-template" + configDecodeTemplateName = configPrefix + "config-llm-decode-template" + configDecodeWorkerPipelineParallelName = configPrefix + "config-llm-decode-worker-pipeline-parallel" + configWorkerPipelineParallelName = configPrefix + "config-llm-worker-pipeline-parallel" + configWorkerDataParallelName = configPrefix + "config-llm-worker-data-parallel" + configDecodeWorkerDataParallelName = configPrefix + "config-llm-decode-worker-data-parallel" + configPrefillTemplateName = configPrefix + "config-llm-prefill-template" + configPrefillWorkerPipelineParallelName = configPrefix + "config-llm-prefill-worker-pipeline-parallel" + configPrefillWorkerDataParallelName = configPrefix + "config-llm-prefill-worker-data-parallel" + configRouterSchedulerName = configPrefix + "config-llm-scheduler" + configRouterRouteName = configPrefix + "config-llm-router-route" +) + +// FIXME move those presets to well-known when they're finally known :) +var _ = sets.New[string]( + configPrefillWorkerPipelineParallelName, + configDecodeWorkerPipelineParallelName, + 
configWorkerPipelineParallelName, +) + +var WellKnownDefaultConfigs = sets.New[string]( + configTemplateName, + configDecodeTemplateName, + configWorkerDataParallelName, + configDecodeWorkerDataParallelName, + configPrefillTemplateName, + configPrefillWorkerDataParallelName, + configRouterSchedulerName, + configRouterRouteName, +) + +// combineBaseRefsConfig applies well-known config overlays to inject default values for various components, when some components are +// enabled. These LLMInferenceServiceConfig resources must exist in either resource namespace (prioritized) or +// SystemNamespace (e.g. `kserve`). +func (r *LLMISVCReconciler) combineBaseRefsConfig(ctx context.Context, llmSvc *v1alpha1.LLMInferenceService, reconcilerConfig *Config) (*v1alpha1.LLMInferenceServiceConfig, error) { + // Creates the initial spec with the merged BaseRefs, so that we know what's "Enabled". + resolvedSpec := *llmSvc.Spec.DeepCopy() + for _, ref := range llmSvc.Spec.BaseRefs { + cfg, err := r.getConfig(ctx, llmSvc, ref.Name) + if err != nil { + return nil, err + } + if cfg != nil { + var resolvedErr error + resolvedSpec, resolvedErr = mergeSpecs(resolvedSpec, cfg.Spec) + if resolvedErr != nil { + return nil, fmt.Errorf("failed to merge specs: %w", resolvedErr) + } + } + } + + if resolvedSpec.Model.Name != nil { + // If original model name was defaulted check if it was not substituted by baseRef + llmSvc.Spec.Model.Name = resolvedSpec.Model.Name + } + + refs := make([]corev1.LocalObjectReference, 0, len(llmSvc.Spec.BaseRefs)) + if resolvedSpec.Router != nil && resolvedSpec.Router.Scheduler != nil && !resolvedSpec.Router.Scheduler.Pool.HasRef() { + refs = append(refs, corev1.LocalObjectReference{Name: configRouterSchedulerName}) + } + if resolvedSpec.Router != nil && resolvedSpec.Router.Route != nil && !resolvedSpec.Router.Route.HTTP.HasRefs() { + refs = append(refs, corev1.LocalObjectReference{Name: configRouterRouteName}) + } + switch { + // Disaggregated prefill and decode 
(P/D) cases. + case resolvedSpec.Prefill != nil && resolvedSpec.Prefill.Worker == nil: + refs = append(refs, corev1.LocalObjectReference{Name: configPrefillTemplateName}) + refs = append(refs, corev1.LocalObjectReference{Name: configDecodeTemplateName}) + case resolvedSpec.Prefill != nil && resolvedSpec.Prefill.Worker != nil && resolvedSpec.Prefill.Parallelism.IsPipelineParallel(): + refs = append(refs, corev1.LocalObjectReference{Name: configDecodeWorkerPipelineParallelName}) + refs = append(refs, corev1.LocalObjectReference{Name: configPrefillWorkerPipelineParallelName}) + case resolvedSpec.Prefill != nil && resolvedSpec.Prefill.Worker != nil && resolvedSpec.Prefill.Parallelism.IsDataParallel(): + refs = append(refs, corev1.LocalObjectReference{Name: configDecodeWorkerDataParallelName}) + refs = append(refs, corev1.LocalObjectReference{Name: configPrefillWorkerDataParallelName}) + // Multi Node without Disaggregated prefill and decode (P/D) cases. + case resolvedSpec.Worker != nil && resolvedSpec.Parallelism.IsPipelineParallel(): + refs = append(refs, corev1.LocalObjectReference{Name: configWorkerPipelineParallelName}) + case resolvedSpec.Worker != nil && resolvedSpec.Parallelism.IsDataParallel(): + refs = append(refs, corev1.LocalObjectReference{Name: configWorkerDataParallelName}) + default: + // Single Node case. + refs = append(refs, corev1.LocalObjectReference{Name: configTemplateName}) + } + // Append explicit base refs to override well know configs. + refs = append(refs, llmSvc.Spec.BaseRefs...) + + specs := make([]v1alpha1.LLMInferenceServiceSpec, 0, len(llmSvc.Spec.BaseRefs)+1) + for _, ref := range refs { + cfg, err := r.getConfig(ctx, llmSvc, ref.Name) + if err != nil { + return nil, err + } + if cfg != nil { + specs = append(specs, cfg.Spec) + } + } + spec, err := MergeSpecs(append(specs, llmSvc.Spec)...) 
+ if err != nil { + return nil, fmt.Errorf("failed to merge specs: %w", err) + } + + llmSvcCfg := &v1alpha1.LLMInferenceServiceConfig{ + ObjectMeta: *llmSvc.ObjectMeta.DeepCopy(), + Spec: spec, + } + + if llmSvcCfg.Spec.Router != nil && + llmSvcCfg.Spec.Router.Scheduler != nil && + llmSvcCfg.Spec.Router.Scheduler.Pool != nil && + llmSvcCfg.Spec.Router.Scheduler.Pool.Spec != nil && + len(llmSvcCfg.Spec.Router.Scheduler.Pool.Spec.Selector) == 0 { + selector := getInferencePoolWorkloadLabelSelector(llmSvc.ObjectMeta, &llmSvcCfg.Spec) + + gieSelector := make(map[igwapi.LabelKey]igwapi.LabelValue, len(selector)) + for k, v := range selector { + gieSelector[igwapi.LabelKey(k)] = igwapi.LabelValue(v) + } + llmSvcCfg.Spec.Router.Scheduler.Pool.Spec.Selector = gieSelector + } + + if llmSvcCfg.Spec.Router != nil && + llmSvcCfg.Spec.Router.Scheduler != nil && + llmSvcCfg.Spec.Router.Scheduler.Template != nil && + llmSvcCfg.Spec.Router.Scheduler.Template.ServiceAccountName == "" { + llmSvcCfg.Spec.Router.Scheduler.Template.ServiceAccountName = kmeta.ChildName(llmSvc.GetName(), "-epp-sa") + } + + llmSvcCfg, err = ReplaceVariables(llmSvc, llmSvcCfg, reconcilerConfig) + if err != nil { + return llmSvcCfg, err + } + + return llmSvcCfg, nil +} + func ReplaceVariables(llmSvc *v1alpha1.LLMInferenceService, llmSvcCfg *v1alpha1.LLMInferenceServiceConfig, reconcilerConfig *Config) (*v1alpha1.LLMInferenceServiceConfig, error) { templateBytes, _ := json.Marshal(llmSvcCfg) buf := bytes.NewBuffer(nil) @@ -58,6 +203,25 @@ func ReplaceVariables(llmSvc *v1alpha1.LLMInferenceService, llmSvcCfg *v1alpha1. return out, nil } +// getConfig retrieves kserveapis.LLMInferenceServiceConfig with the given name from either the kserveapis.LLMInferenceService +// namespace or from the SystemNamespace (e.g. 'kserve'), prioritizing the former. 
+func (r *LLMISVCReconciler) getConfig(ctx context.Context, llmSvc *v1alpha1.LLMInferenceService, name string) (*v1alpha1.LLMInferenceServiceConfig, error) {
+	// First choice: a config stored in the same namespace as the LLMInferenceService.
+	local := &v1alpha1.LLMInferenceServiceConfig{}
+	localErr := r.Client.Get(ctx, client.ObjectKey{Name: name, Namespace: llmSvc.Namespace}, local)
+	if localErr == nil {
+		return local, nil
+	}
+	// Anything other than "not found" is reported as-is; no fallback is attempted.
+	if !apierrors.IsNotFound(localErr) {
+		return nil, fmt.Errorf("failed to get LLMInferenceServiceConfig %s/%s: %w", llmSvc.Namespace, name, localErr)
+	}
+
+	// Not found locally: retry in constants.KServeNamespace using a fresh object.
+	shared := &v1alpha1.LLMInferenceServiceConfig{}
+	if sharedErr := r.Client.Get(ctx, client.ObjectKey{Name: name, Namespace: constants.KServeNamespace}, shared); sharedErr != nil {
+		// TODO: add available LLMInferenceServiceConfig in system namespace and llmSvc.Namespace namespace if not found
+
+		return nil, fmt.Errorf("failed to get LLMInferenceServiceConfig %q from namespaces [%q, %q]: %w", name, llmSvc.Namespace, constants.KServeNamespace, sharedErr)
+	}
+	return shared, nil
+}
+
 func MergeSpecs(cfgs ...v1alpha1.LLMInferenceServiceSpec) (v1alpha1.LLMInferenceServiceSpec, error) { if len(cfgs) == 0 { return v1alpha1.LLMInferenceServiceSpec{}, nil diff --git a/pkg/controller/v1alpha1/llmisvc/config_presets_test.go b/pkg/controller/v1alpha1/llmisvc/config_presets_test.go new file mode 100644 index 00000000000..5ceb488c38b --- /dev/null +++ b/pkg/controller/v1alpha1/llmisvc/config_presets_test.go @@ -0,0 +1,382 @@ +/* +Copyright 2025 The KServe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+*/ + +package llmisvc_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/equality" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/utils/ptr" + "sigs.k8s.io/yaml" + + "github.com/kserve/kserve/pkg/controller/v1alpha1/llmisvc" + + kservetesting "github.com/kserve/kserve/pkg/testing" + + "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" +) + +func TestPresetFiles(t *testing.T) { + presetsDir := filepath.Join(kservetesting.ProjectRoot(), "config", "llmisvc") + + llmSvc := llmisvc.LLMInferenceServiceSample() + kserveSystemConfig := llmisvc.Config{ + SystemNamespace: "kserve", + IngressGatewayName: "kserve-ingress-gateway", + IngressGatewayNamespace: "kserve", + } + + tt := map[string]struct { + expected *v1alpha1.LLMInferenceServiceConfig + }{ + "config-llm-decode-worker-data-parallel.yaml": { + expected: &v1alpha1.LLMInferenceServiceConfig{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "serving.kserve.io/v1alpha1", + Kind: "LLMInferenceServiceConfig", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "kserve-config-llm-decode-worker-data-parallel", + }, + Spec: v1alpha1.LLMInferenceServiceSpec{ + WorkloadSpec: v1alpha1.WorkloadSpec{ + Worker: &corev1.PodSpec{ + Volumes: []corev1.Volume{ + { + Name: "home", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + { + Name: "dshm", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{ + Medium: corev1.StorageMediumMemory, + SizeLimit: ptr.To(resource.MustParse("1Gi")), + }, + }, + }, + { + Name: "model-cache", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + { + Name: "tls-certs", + VolumeSource: corev1.VolumeSource{Secret: &corev1.SecretVolumeSource{SecretName: "test-llm-preset-kserve-self-signed-certs"}}, + }, + }, + 
TerminationGracePeriodSeconds: ptr.To(int64(30)), + InitContainers: []corev1.Container{ + { + Name: "llm-d-routing-sidecar", + Image: "ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0", + Args: []string{ + "--port=8000", + "--vllm-port=8001", + "--secure-proxy=true", + "--cert-path=/etc/ssl/certs", + "--decoder-use-tls=true", + "--decoder-tls-insecure-skip-verify=true", + "--prefiller-use-tls=true", + "--prefiller-tls-insecure-skip-verify=true", + "--enable-ssrf-protection=true", + }, + Env: []corev1.EnvVar{ + { + Name: "INFERENCE_POOL_NAMESPACE", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "metadata.namespace", + }, + }, + }, + }, + Ports: []corev1.ContainerPort{ + { + ContainerPort: 8000, + Protocol: corev1.ProtocolTCP, + }, + }, + RestartPolicy: ptr.To(corev1.ContainerRestartPolicyAlways), + TerminationMessagePath: "/dev/termination-log", + TerminationMessagePolicy: "FallbackToLogsOnError", + ImagePullPolicy: "IfNotPresent", + VolumeMounts: []corev1.VolumeMount{ + { + Name: "tls-certs", + ReadOnly: true, + MountPath: "/etc/ssl/certs", + }, + }, + ReadinessProbe: &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: "/health", + Port: intstr.FromInt32(8000), + Scheme: corev1.URISchemeHTTPS, + }, + }, + InitialDelaySeconds: 10, + TimeoutSeconds: 5, + PeriodSeconds: 10, + FailureThreshold: 10, + }, + LivenessProbe: &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: "/health", + Port: intstr.FromInt32(8000), + Scheme: corev1.URISchemeHTTPS, + }, + }, + InitialDelaySeconds: 10, + TimeoutSeconds: 10, + PeriodSeconds: 10, + FailureThreshold: 3, + }, + }, + }, + Containers: []corev1.Container{ + { + Name: "main", + Image: "ghcr.io/llm-d/llm-d:v0.2.0", + Command: []string{"/bin/sh", "-c"}, + Ports: []corev1.ContainerPort{ + { + ContainerPort: 8001, + Protocol: corev1.ProtocolTCP, + }, + }, + VolumeMounts: []corev1.VolumeMount{ + { + Name: "home", + 
MountPath: "/home", + }, + { + Name: "dshm", + MountPath: "/dev/shm", + }, + { + Name: "model-cache", + MountPath: "/models", + }, + { + Name: "tls-certs", + ReadOnly: true, + MountPath: "/etc/ssl/certs", + }, + }, + LivenessProbe: &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: "/health", + Port: intstr.FromInt32(8001), + Scheme: corev1.URISchemeHTTPS, + }, + }, + InitialDelaySeconds: 120, + PeriodSeconds: 10, + TimeoutSeconds: 10, + FailureThreshold: 3, + }, + ReadinessProbe: &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: "/health", + Port: intstr.FromInt32(8001), + Scheme: corev1.URISchemeHTTPS, + }, + }, + InitialDelaySeconds: 10, + PeriodSeconds: 10, + TimeoutSeconds: 5, + FailureThreshold: 60, + }, + SecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: ptr.To(false), + Capabilities: &corev1.Capabilities{ + Add: []corev1.Capability{ + "IPC_LOCK", + "SYS_RAWIO", + }, + }, + }, + Env: []corev1.EnvVar{ + { + Name: "HOME", + Value: "/home", + }, + { + Name: "VLLM_LOGGING_LEVEL", + Value: "INFO", + }, + { + Name: "HF_HUB_CACHE", + Value: "/models", + }, + }, + TerminationMessagePath: "/dev/termination-log", + TerminationMessagePolicy: "FallbackToLogsOnError", + ImagePullPolicy: "IfNotPresent", + Stdin: true, + TTY: true, + Args: []string{` +START_RANK=$(( ${LWS_WORKER_INDEX:-0} * 2 )) +if [ "${LWS_WORKER_INDEX:-0}" -eq 0 ]; then + ################# + # Leader-only launch + ################# + vllm serve \ + llama \ + --port 8001 \ + --api-server-count 4 \ + --disable-log-requests \ +--enable-expert-parallel \ +--tensor-parallel-size 1 \ + --data-parallel-size 4 \ + --data-parallel-size-local 2 \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port 5555 \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key 
+else + ################# + # Worker-only launch + ################# + vllm serve \ + llama \ + --port 8001 \ + --disable-log-requests \ +--enable-expert-parallel \ +--tensor-parallel-size 1 \ + --data-parallel-size 4 \ + --data-parallel-size-local 2 \ + --data-parallel-address $(LWS_LEADER_ADDRESS) \ + --data-parallel-rpc-port 5555 \ + --data-parallel-start-rank $START_RANK \ + --trust-remote-code \ + --headless \ + --enable-ssl-refresh \ + --ssl-certfile \ + /etc/ssl/certs/tls.crt \ + --ssl-keyfile \ + /etc/ssl/certs/tls.key +fi`}, + }, + }, + }, + }, + }, + }, + }, + } + + remaining := llmisvc.WellKnownDefaultConfigs.Clone() + + _ = filepath.Walk(presetsDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + t.Errorf("Failed to walk %s: %v", path, err) + return err + } + + filename := info.Name() + if info.IsDir() || !strings.HasSuffix(filename, ".yaml") || !strings.HasPrefix(filename, "config-") { + return nil + } + + t.Run(filename, func(t *testing.T) { + filePath := filepath.Join(presetsDir, filename) + data, err := os.ReadFile(filePath) + if err != nil { + t.Errorf("Failed to read file %s: %v", filePath, err) + return + } + + config := loadConfig(t, data, filePath) + + name := config.ObjectMeta.Name + if !llmisvc.WellKnownDefaultConfigs.Has(name) { + t.Fatalf("Expected %s to exist in WellKnownDefaultConfigs %#v", name, llmisvc.WellKnownDefaultConfigs) + } + // Remove from the tracked set + remaining = remaining.Delete(name) + + out, err := llmisvc.ReplaceVariables(llmSvc, config, &kserveSystemConfig) + if err != nil { + t.Errorf("ReplaceVariables() failed for %s: %v", filename, err) + } + + // Verify the actual Spec rendered if provided for the found file. 
+ if tc, exist := tt[filename]; exist { + if !equality.Semantic.DeepEqual(tc.expected, out) { + diff := cmp.Diff(tc.expected, out) + t.Errorf("ReplaceVariables() returned unexpected diff (-want +got):\n%s", diff) + } + } + }) + + return nil + }) + + if remaining.Len() > 0 { + t.Errorf("Found %d remaining well-known-configs that are missing as manifest files: %#v", remaining.Len(), remaining) + } +} +
+// loadConfig decodes the YAML manifest in data into an LLMInferenceServiceConfig and
+// checks that its APIVersion, Kind and name are set as expected. filePath is only used
+// in failure messages.
+//
+// A manifest that cannot be decoded aborts the calling test immediately; the other
+// checks are reported as non-fatal errors and the decoded config is still returned.
+func loadConfig(t *testing.T, data []byte, filePath string) *v1alpha1.LLMInferenceServiceConfig {
+	t.Helper()
+
+	config := &v1alpha1.LLMInferenceServiceConfig{}
+	// Decode exactly once and stop on failure: the result is dereferenced by the caller,
+	// so returning nil here would only turn a clear failure into a nil-pointer panic.
+	if err := yaml.Unmarshal(data, config); err != nil {
+		t.Fatalf("Failed to unmarshal YAML from %s: %v", filePath, err)
+	}
+
+	expectedGroupVersion := v1alpha1.LLMInferenceServiceConfigGVK.GroupVersion().String()
+	if config.APIVersion != expectedGroupVersion {
+		t.Errorf("Expected APIVersion to be '%s', got '%s'", expectedGroupVersion, config.APIVersion)
+	}
+
+	expectedKind := v1alpha1.LLMInferenceServiceConfigGVK.Kind
+	if config.Kind != expectedKind {
+		t.Errorf("Expected Kind to be '%s', got %s", expectedKind, config.Kind)
+	}
+
+	if config.ObjectMeta.Name == "" {
+		t.Error("Expected ObjectMeta.Name to be set")
+	}
+
+	return config
+}
diff --git a/pkg/controller/v1alpha1/llmisvc/sample.go b/pkg/controller/v1alpha1/llmisvc/sample.go new file mode 100644 index 00000000000..8f9a3603f3e --- /dev/null +++ b/pkg/controller/v1alpha1/llmisvc/sample.go @@ -0,0 +1,226 @@ +/* +Copyright 2025 The KServe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmisvc + +import ( + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/utils/ptr" + "knative.dev/pkg/apis" + + "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" +) +
+// LLMInferenceServiceSample defines a full sample of LLMInferenceService that can be used
+// as a basis to apply LLMInferenceServiceConfigs. It is used for validating templated values
+// in LLMInferenceServiceConfig CR.
+//
+// The sample is named "test-llm-preset" in namespace "test-llm-preset-test". It sets the model,
+// the main workload (Template and Worker pod specs), a Prefill workload, and a Router with
+// route, gateway and scheduler sections. A new object is built on every call.
+func LLMInferenceServiceSample() *v1alpha1.LLMInferenceService {
+	// Shared by the object metadata and by the scheduler's POOL_NAME / POOL_NAMESPACE env vars below.
+	svcName := "test-llm-preset"
+	nsName := "test-llm-preset-test"
+	// NOTE(review): the parse error is discarded; presumably acceptable for this constant input — confirm.
+	modelURL, _ := apis.ParseURL("llama")
+
+	return &v1alpha1.LLMInferenceService{
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      svcName,
+			Namespace: nsName,
+			Labels: map[string]string{
+				"app.kubernetes.io/name":      "llminferenceservice",
+				"app.kubernetes.io/instance":  svcName,
+				"app.kubernetes.io/component": "inference",
+			},
+			Annotations: map[string]string{
+				"serving.kserve.io/model-uri": modelURL.String(),
+			},
+		},
+		Spec: v1alpha1.LLMInferenceServiceSpec{
+			// Model: the same URL is used for the URI and for the "storageUri" storage parameter.
+			Model: v1alpha1.LLMModelSpec{
+				Name: ptr.To("llama"),
+				URI:  *modelURL,
+				Storage: &v1alpha1.LLMStorageSpec{
+					Path: ptr.To("/models"),
+					Parameters: &map[string]string{
+						"storageUri": modelURL.String(),
+					},
+				},
+			},
+			// Main workload: replicas, parallelism settings, and both Template and Worker pod specs.
+			WorkloadSpec: v1alpha1.WorkloadSpec{
+				Replicas: ptr.To[int32](2),
+				Parallelism: &v1alpha1.ParallelismSpec{
+					Data:      ptr.To[int32](4),
+					DataLocal: ptr.To[int32](2),
+					Tensor:    ptr.To[int32](1),
+					Expert:    true,
+				},
+				Template: &corev1.PodSpec{
+					Containers: []corev1.Container{
+						{
+							Name:  "kserve-container",
+							Image: "ghcr.io/llm-d/llm-d:v0.2.0",
+							Ports: []corev1.ContainerPort{
+								{
+									ContainerPort: 8000,
+									Name:          "http",
+									Protocol:      corev1.ProtocolTCP,
+								},
+							},
+							Resources: corev1.ResourceRequirements{
+								Requests: corev1.ResourceList{
+									corev1.ResourceCPU:    resource.MustParse("2"),
+									corev1.ResourceMemory: resource.MustParse("4Gi"),
+								},
+								Limits: corev1.ResourceList{
+									corev1.ResourceCPU:    resource.MustParse("4"),
+									corev1.ResourceMemory: resource.MustParse("8Gi"),
+								},
+							},
+							Env: []corev1.EnvVar{
+								{
+									Name:  "MODEL_NAME",
+									Value: "facebook/opt-125m",
+								},
+								{
+									Name:  "VLLM_LOGGING_LEVEL",
+									Value: "INFO",
+								},
+							},
+						},
+					},
+					Tolerations: []corev1.Toleration{
+						{
+							Key:      "nvidia.com/gpu",
+							Operator: corev1.TolerationOpExists,
+							Effect:   corev1.TaintEffectNoSchedule,
+						},
+					},
+					NodeSelector: map[string]string{
+						"node.kubernetes.io/instance-type": "gpu-node",
+					},
+				},
+				Worker: &corev1.PodSpec{
+					Containers: []corev1.Container{
+						{
+							Name: "kserve-container",
+							// NOTE(review): tag is "0.2.0" here but "v0.2.0" in the Template and Prefill
+							// containers — confirm this difference is intended.
+							Image: "ghcr.io/llm-d/llm-d:0.2.0",
+							Resources: corev1.ResourceRequirements{
+								Requests: corev1.ResourceList{
+									corev1.ResourceCPU:    resource.MustParse("1"),
+									corev1.ResourceMemory: resource.MustParse("2Gi"),
+									"nvidia.com/gpu":      resource.MustParse("1"),
+								},
+								Limits: corev1.ResourceList{
+									corev1.ResourceCPU:    resource.MustParse("2"),
+									corev1.ResourceMemory: resource.MustParse("4Gi"),
+									"nvidia.com/gpu":      resource.MustParse("1"),
+								},
+							},
+						},
+					},
+				},
+			},
+			// Prefill workload: a Template only, no Worker pod spec.
+			Prefill: &v1alpha1.WorkloadSpec{
+				Replicas: ptr.To[int32](1),
+				Parallelism: &v1alpha1.ParallelismSpec{
+					Tensor:   ptr.To[int32](1),
+					Pipeline: ptr.To[int32](1),
+				},
+				Template: &corev1.PodSpec{
+					Containers: []corev1.Container{
+						{
+							Name:  "kserve-container",
+							Image: "ghcr.io/llm-d/llm-d:v0.2.0",
+							Ports: []corev1.ContainerPort{
+								{
+									ContainerPort: 8000,
+									Name:          "http",
+									Protocol:      corev1.ProtocolTCP,
+								},
+							},
+							Resources: corev1.ResourceRequirements{
+								Requests: corev1.ResourceList{
+									corev1.ResourceCPU:    resource.MustParse("4"),
+									corev1.ResourceMemory: resource.MustParse("8Gi"),
+									"nvidia.com/gpu":      resource.MustParse("2"),
+								},
+								Limits: corev1.ResourceList{
+									corev1.ResourceCPU:    resource.MustParse("8"),
+									corev1.ResourceMemory: resource.MustParse("16Gi"),
+									"nvidia.com/gpu":      resource.MustParse("2"),
+								},
+							},
+						},
+					},
+				},
+			},
+			// Router: references an HTTP route, a gateway and an inference pool by name,
+			// and carries a scheduler pod template.
+			Router: &v1alpha1.RouterSpec{
+				Route: &v1alpha1.GatewayRoutesSpec{
+					HTTP: &v1alpha1.HTTPRouteSpec{
+						Refs: []corev1.LocalObjectReference{
+							{Name: "custom-http-route"},
+						},
+					},
+				},
+				Gateway: &v1alpha1.GatewaySpec{
+					Refs: []v1alpha1.UntypedObjectReference{
+						{
+							Name:      "kserve-ingress-gateway",
+							Namespace: "kserve",
+						},
+					},
+				},
+				Scheduler: &v1alpha1.SchedulerSpec{
+					Pool: &v1alpha1.InferencePoolSpec{
+						Ref: &corev1.LocalObjectReference{
+							Name: "custom-inference-pool",
+						},
+					},
+					Template: &corev1.PodSpec{
+						Containers: []corev1.Container{
+							{
+								Name:  "scheduler",
+								Image: "ghcr.io/llm-d/llm-d-inference-scheduler:0.0.4",
+								Ports: []corev1.ContainerPort{
+									{
+										ContainerPort: 9002,
+										Name:          "grpc",
+										Protocol:      corev1.ProtocolTCP,
+									},
+									{
+										ContainerPort: 9003,
+										Name:          "grpc-health",
+										Protocol:      corev1.ProtocolTCP,
+									},
+								},
+								Resources: corev1.ResourceRequirements{
+									Requests: corev1.ResourceList{
+										corev1.ResourceCPU:    resource.MustParse("256m"),
+										corev1.ResourceMemory: resource.MustParse("500Mi"),
+									},
+								},
+								Env: []corev1.EnvVar{
+									{Name: "ENABLE_LOAD_AWARE_SCORER", Value: "true"},
+									{Name: "POOL_NAME", Value: svcName + "-inference-pool"},
+									{Name: "POOL_NAMESPACE", Value: nsName},
+								},
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+}
diff --git a/pkg/controller/v1alpha1/llmisvc/utils.go b/pkg/controller/v1alpha1/llmisvc/utils.go new file mode 100644 index 00000000000..bd36f0502a9 --- /dev/null +++ b/pkg/controller/v1alpha1/llmisvc/utils.go @@ -0,0 +1,32 @@ +/* +Copyright 2025 The KServe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmisvc + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" +) +
+// getInferencePoolWorkloadLabelSelector builds a label set made of two fixed entries
+// (part-of and component) plus the name taken from meta. The spec argument is
+// currently ignored. A new map is returned on every call.
+func getInferencePoolWorkloadLabelSelector(meta metav1.ObjectMeta, _ *v1alpha1.LLMInferenceServiceSpec) map[string]string {
+	labels := make(map[string]string, 3)
+	labels["app.kubernetes.io/part-of"] = "llminferenceservice"
+	labels["app.kubernetes.io/name"] = meta.GetName()
+	labels["kserve.io/component"] = "workload"
+	return labels
+}
From 99dae0d5ebaad210d50e142eb1382980e3bf3dd4 Mon Sep 17 00:00:00 2001 From: Bartosz Majsak Date: Mon, 28 Jul 2025 16:18:15 +0200 Subject: [PATCH 9/9] fix: fixes misleading print for parallelism when $# > 2 (#784) (#4619) Signed-off-by: Bartosz Majsak --- test/scripts/gh-actions/run-e2e-tests.sh | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/scripts/gh-actions/run-e2e-tests.sh b/test/scripts/gh-actions/run-e2e-tests.sh index 88ab0ea2308..bc573712502 100755 --- a/test/scripts/gh-actions/run-e2e-tests.sh +++ b/test/scripts/gh-actions/run-e2e-tests.sh @@ -22,16 +22,12 @@ set -o nounset set -o pipefail echo "Starting E2E functional tests ..." -if [ $# -eq 2 ]; then - echo "Parallelism requested for pytest is $2" -else - echo "No parallelism requested for pytest. Will use default value of 1" -fi - MARKER="${1}" PARALLELISM="${2:-1}" NETWORK_LAYER="${3:-'istio'}" +echo "Parallelism requested for pytest is ${PARALLELISM}" + source python/kserve/.venv/bin/activate pushd test/e2e >/dev/null if [[ $MARKER == "raw" && $NETWORK_LAYER == "istio-ingress" ]]; then