Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

/**
* The primary class that starts an embedded Jetty server
*
* @author Steve Springett
* @since 1.0.0
*/
Expand Down Expand Up @@ -73,7 +74,7 @@ public static void main(final String[] args) throws Exception {

final Server server = new Server();
final HttpConfiguration httpConfig = new HttpConfiguration();
httpConfig.addCustomizer( new org.eclipse.jetty.server.ForwardedRequestCustomizer() ); // Add support for X-Forwarded headers
httpConfig.addCustomizer(new org.eclipse.jetty.server.ForwardedRequestCustomizer()); // Add support for X-Forwarded headers

// Enable legacy (mimicking Jetty 9) URI compliance.
// This is required to allow URL encoding in path segments, e.g. "/foo/bar%2Fbaz".
Expand All @@ -89,7 +90,7 @@ public static void main(final String[] args) throws Exception {
// here, the only viable long-term solution is to adapt REST APIs to follow Servlet API 6 spec.
httpConfig.setUriCompliance(UriCompliance.LEGACY);

final HttpConnectionFactory connectionFactory = new HttpConnectionFactory( httpConfig );
final HttpConnectionFactory connectionFactory = new HttpConnectionFactory(httpConfig);
final ServerConnector connector = new ServerConnector(server, connectionFactory);
connector.setHost(host);
connector.setPort(port);
Expand All @@ -102,6 +103,7 @@ public static void main(final String[] args) throws Exception {
context.setErrorHandler(new ErrorHandler());
context.setInitParameter("org.eclipse.jetty.servlet.Default.dirAllowed", "false");
context.setAttribute("org.eclipse.jetty.server.webapp.ContainerIncludeJarPattern", ".*/[^/]*taglibs.*\\.jar$");
context.setThrowUnavailableOnStartupException(true);

// Prevent loading of logging classes
context.getProtectedClassMatcher().add("org.apache.log4j.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ public enum ConfigKey implements Config.Key {
VULNERABILITY_POLICY_S3_BUCKET_NAME("vulnerability.policy.s3.bucket.name", null),
VULNERABILITY_POLICY_S3_BUNDLE_NAME("vulnerability.policy.s3.bundle.name", null),
VULNERABILITY_POLICY_S3_REGION("vulnerability.policy.s3.region", null),
DATABASE_MIGRATION_URL("database.migration.url", null),
DATABASE_MIGRATION_USERNAME("database.migration.username", null),
DATABASE_MIGRATION_PASSWORD("database.migration.password", null),
DATABASE_RUN_MIGRATIONS("database.run.migrations", true),
DATABASE_RUN_MIGRATIONS_ONLY("database.run.migrations.only", false),
INIT_TASKS_ENABLED("init.tasks.enabled", true),
INIT_TASKS_DATABASE_URL("init.tasks.database.url", null),
INIT_TASKS_DATABASE_USERNAME("init.tasks.database.username", null),
INIT_TASKS_DATABASE_PASSWORD("init.tasks.database.password", null),
INIT_AND_EXIT("init.and.exit", false),

DEV_SERVICES_ENABLED("dev.services.enabled", false),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
*/
package org.dependencytrack.health;

import alpine.Config;
import alpine.common.logging.Logger;
import alpine.server.health.HealthCheckRegistry;
import alpine.server.health.checks.DatabaseHealthCheck;
import org.dependencytrack.common.ConfigKey;
import org.dependencytrack.event.kafka.processor.ProcessorsHealthCheck;

import jakarta.servlet.ServletContextEvent;
Expand All @@ -34,12 +32,6 @@ public class HealthCheckInitializer implements ServletContextListener {

@Override
public void contextInitialized(final ServletContextEvent event) {
if (Config.getInstance().getPropertyAsBoolean(ConfigKey.INIT_AND_EXIT)) {
LOGGER.debug("Not registering health checks because %s is enabled"
.formatted(ConfigKey.INIT_AND_EXIT.getPropertyName()));
return;
}

LOGGER.info("Registering health checks");
HealthCheckRegistry.getInstance().register("database", new DatabaseHealthCheck());
HealthCheckRegistry.getInstance().register("kafka-processors", new ProcessorsHealthCheck());
Expand Down
51 changes: 51 additions & 0 deletions apiserver/src/main/java/org/dependencytrack/init/InitTask.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* This file is part of Dependency-Track.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
* Copyright (c) OWASP Foundation. All Rights Reserved.
*/
package org.dependencytrack.init;

/**
 * A task to be run on application startup.
 * <p>
 * Implementations are discovered via {@link java.util.ServiceLoader} and are
 * executed in order of descending {@link #priority()}, with ties broken by
 * {@link #name()}. Individual tasks can be disabled by setting the
 * {@code init.task.<name>.enabled} configuration property to {@code false}.
 *
 * @since 5.6.0
 */
public interface InitTask {

    // Bounds of the valid priority range. Priorities outside
    // [PRIORITY_LOWEST..PRIORITY_HIGHEST] are rejected at startup.
    int PRIORITY_HIGHEST = 100;
    int PRIORITY_LOWEST = 0;

    /**
     * @return Priority of the task. Tasks with a higher priority are executed
     *         before tasks with a lower priority. Must be within
     *         {@link #PRIORITY_LOWEST} and {@link #PRIORITY_HIGHEST} (inclusive).
     * @see #PRIORITY_HIGHEST
     * @see #PRIORITY_LOWEST
     */
    int priority();

    /**
     * @return Name of the task. Must be globally unique across all registered tasks;
     *         duplicates are rejected at startup.
     */
    String name();

    /**
     * Execute the task.
     *
     * @param ctx Context in which the task is executed.
     * @throws Exception When the task execution failed. A failing task aborts
     *         execution of all remaining tasks.
     */
    void execute(InitTaskContext ctx) throws Exception;

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@
* SPDX-License-Identifier: Apache-2.0
* Copyright (c) OWASP Foundation. All Rights Reserved.
*/
package org.dependencytrack.persistence.defaults;
package org.dependencytrack.init;

import java.io.IOException;
import alpine.Config;

public interface IDefaultObjectImporter {

boolean shouldImport();

void loadDefaults() throws IOException;
import javax.sql.DataSource;

/**
 * Context available to {@link InitTask}s.
 * <p>
 * TODO: Introduce a tiny abstraction over {@link Config} such that
 * Alpine specifics don't bleed through to {@link InitTask}s.
 *
 * @param config A {@link Config} instance to read application configuration.
 * @param dataSource A {@link DataSource} which may be used for database interactions.
 * @since 5.6.0
 */
public record InitTaskContext(Config config, DataSource dataSource) {

    /**
     * Compact constructor rejecting {@code null} components, so tasks can rely
     * on both being present without defensive checks of their own.
     *
     * @throws NullPointerException When {@code config} or {@code dataSource} is {@code null}.
     */
    public InitTaskContext {
        java.util.Objects.requireNonNull(config, "config must not be null");
        java.util.Objects.requireNonNull(dataSource, "dataSource must not be null");
    }

}
180 changes: 180 additions & 0 deletions apiserver/src/main/java/org/dependencytrack/init/InitTaskExecutor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* This file is part of Dependency-Track.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
* Copyright (c) OWASP Foundation. All Rights Reserved.
*/
package org.dependencytrack.init;

import alpine.Config;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import jakarta.servlet.ServletContextListener;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Predicate;

import static java.util.Comparator.comparing;
import static java.util.Comparator.reverseOrder;
import static java.util.Objects.requireNonNull;
import static org.dependencytrack.util.ConfigUtil.getPassThroughProperties;

/**
 * Executes all registered {@link InitTask}s on application startup.
 * <p>
 * Tasks are discovered via {@link ServiceLoader}, validated (globally unique name,
 * priority within {@link InitTask#PRIORITY_LOWEST}..{@link InitTask#PRIORITY_HIGHEST}),
 * and executed in order of descending priority, with ties broken by name.
 * Execution is serialized across application instances using a PostgreSQL
 * session-level advisory lock.
 *
 * @since 5.6.0
 */
final class InitTaskExecutor implements ServletContextListener {

    private static final Logger LOGGER = LoggerFactory.getLogger(InitTaskExecutor.class);

    // String#hashCode is deterministic across JVM instances,
    // so all application replicas compete for the same lock key.
    private static final long ADVISORY_LOCK_KEY = "dependency-track-init-tasks".hashCode();

    private final Config config;
    private final DataSource dataSource;
    private final List<InitTask> tasks;

    /**
     * Creates an executor whose tasks are discovered via {@link ServiceLoader}.
     */
    InitTaskExecutor(final Config config, final DataSource dataSource) {
        this(config, dataSource, loadInitTasks());
    }

    InitTaskExecutor(final Config config, final DataSource dataSource, final List<InitTask> tasks) {
        this.config = requireNonNull(config, "config must not be null");
        this.dataSource = requireNonNull(dataSource, "dataSource must not be null");
        this.tasks = requireNonNull(tasks, "tasks must not be null");
    }

    /**
     * Validates, filters, orders, and executes all registered tasks while
     * holding a PostgreSQL session-level advisory lock.
     *
     * @throws IllegalStateException When task validation fails, a task fails,
     *         or the advisory lock cannot be acquired.
     */
    public void execute() {
        final List<InitTask> orderedTasks = this.tasks.stream()
                .peek(requireUniqueName())
                .peek(requireValidPriority())
                .filter(isTaskEnabled())
                // Highest priority first; names break ties to keep ordering deterministic.
                .sorted(comparing(InitTask::priority, reverseOrder())
                        .thenComparing(InitTask::name))
                .toList();

        final long startTimeNanos = System.nanoTime();

        // We're using session-level advisory locks here,
        // which won't work when using PgBouncer in "transaction" mode.
        // We can't use transaction-level locking because that would
        // block some DDL statements executed by database migrations,
        // such as "CREATE INDEX CONCURRENTLY".
        //
        // This GitLab issue describes the problem well:
        // https://gitlab.com/gitlab-com/support/support-training/-/issues/3823#locks-block-a-gitlab-database-migration
        //
        // The intended workaround is to use a separate set of connection
        // details specifically for init tasks, which bypasses PgBouncer.
        try (final Connection connection = dataSource.getConnection();
             final PreparedStatement lockStatement = connection.prepareStatement("""
                     SELECT PG_ADVISORY_LOCK(?)
                     """);
             final PreparedStatement unlockStatement = connection.prepareStatement("""
                     SELECT PG_ADVISORY_UNLOCK(?)
                     """)) {
            LOGGER.debug("Trying to acquire lock {}", ADVISORY_LOCK_KEY);
            lockStatement.setLong(1, ADVISORY_LOCK_KEY);
            lockStatement.execute();
            LOGGER.debug(
                    "Lock {} acquired after {}ms",
                    ADVISORY_LOCK_KEY,
                    TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNanos));

            final var taskContext = new InitTaskContext(config, dataSource);

            try {
                runTasks(orderedTasks, taskContext);
            } finally {
                // Unlock failures are logged rather than thrown: throwing from the
                // finally block would mask a task failure propagating out of the try
                // block, and the session-level lock is released anyway once the
                // connection is closed.
                releaseLock(unlockStatement);
            }
        } catch (SQLException e) {
            throw new IllegalStateException("Failed to acquire or release lock " + ADVISORY_LOCK_KEY, e);
        }

        LOGGER.info(
                "All init tasks completed in {}ms",
                TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNanos));
    }

    // Runs each task in order, logging per-task durations.
    // The first failing task aborts execution of all remaining tasks.
    private void runTasks(final List<InitTask> orderedTasks, final InitTaskContext taskContext) {
        long taskStartTimeNanos;
        for (final InitTask task : orderedTasks) {
            taskStartTimeNanos = System.nanoTime();
            LOGGER.info("Executing init task {}", task.name());
            try {
                task.execute(taskContext);
                LOGGER.info(
                        "Completed init task {} in {}ms",
                        task.name(),
                        TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - taskStartTimeNanos));
            } catch (Exception e) {
                throw new IllegalStateException("Failed to execute init task " + task.name(), e);
            }
        }
    }

    // Releases the advisory lock. Never throws, so the caller's original
    // exception (if any) is never masked by an unlock failure.
    private void releaseLock(final PreparedStatement unlockStatement) {
        LOGGER.debug("Releasing lock {}", ADVISORY_LOCK_KEY);
        try {
            unlockStatement.setLong(1, ADVISORY_LOCK_KEY);
            // ResultSet is held in try-with-resources to avoid leaking it.
            try (final ResultSet rs = unlockStatement.executeQuery()) {
                if (!rs.next() || !rs.getBoolean(1)) {
                    LOGGER.warn("""
                            Lock {} could not be released, likely because a connection pooler \
                            in "transaction" mode is being used. Ensure that a direct database connection \
                            is provided when executing init tasks.""", ADVISORY_LOCK_KEY);
                }
            }
        } catch (SQLException e) {
            LOGGER.warn("Failed to release lock {}", ADVISORY_LOCK_KEY, e);
        }
    }

    // Discovers InitTask implementations registered via META-INF/services.
    private static List<InitTask> loadInitTasks() {
        return ServiceLoader.load(InitTask.class).stream()
                .map(ServiceLoader.Provider::get)
                .toList();
    }

    // Stateful validator: fails fast when two tasks share the same name.
    private Consumer<InitTask> requireUniqueName() {
        final var seenTaskClassesByTaskName =
                new HashMap<String, Class<? extends InitTask>>(this.tasks.size());

        return task -> {
            final Class<? extends InitTask> previousClass =
                    seenTaskClassesByTaskName.put(task.name(), task.getClass());
            if (previousClass != null) {
                throw new IllegalStateException(
                        "Duplicate task name %s: Registered by %s and %s".formatted(
                                task.name(), previousClass.getName(), task.getClass().getName()));
            }
        };
    }

    // Fails fast when a task's priority is outside [PRIORITY_LOWEST..PRIORITY_HIGHEST].
    private Consumer<InitTask> requireValidPriority() {
        return task -> {
            if (task.priority() < InitTask.PRIORITY_LOWEST
                || task.priority() > InitTask.PRIORITY_HIGHEST) {
                throw new IllegalStateException(
                        "Invalid priority of task %s: Must be within [%d..%d] but is %d".formatted(
                                task.name(), InitTask.PRIORITY_LOWEST, InitTask.PRIORITY_HIGHEST, task.priority()));
            }
        };
    }

    // A task is enabled unless "init.task.<name>.enabled" is explicitly "false".
    private Predicate<InitTask> isTaskEnabled() {
        return task -> {
            final String propertyPrefix = "init.task." + task.name();
            final Map<String, String> properties = getPassThroughProperties(config, propertyPrefix);
            return !"false".equals(properties.get(propertyPrefix + ".enabled"));
        };
    }

}
Loading
Loading