Java 21 Concurrency Explained

Published: 2025-05-25

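Contents

I. Structured Concurrency

1. Core Concepts

2. Basic Usage

3. Practical Example: Parallel API Calls and Result Aggregation

4. Structured Concurrency vs. CompletableFuture

5. Advantages of Structured Concurrency

II. Virtual Threads for Millions of Concurrent Connections

1. A High-Concurrency HTTP Server

2. Handling a Million WebSocket Connections

3. Integrating a Database Connection Pool with Virtual Threads

4. Performance Comparison: Virtual Threads vs. Platform Threads

III. Efficient Thread-Scoped Data Sharing with Scoped Values

1. Basic Usage

2. Request Context Propagation in Web Applications

3. Distributed Tracing and Log Correlation

4. ScopedValue vs. ThreadLocal

5. Practical Application: Context Propagation Across Microservices

Summary and Best Practices

Structured Concurrency Best Practices

Virtual Thread Best Practices

ScopedValue Best Practices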


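I. Structured Concurrency

Structured concurrency is a preview feature in Java 21 (JEP 453). It was incubated in Java 19 and 20, and code that uses it must be compiled and run with --enable-preview. It groups concurrent tasks into a structure that has a well-defined lifetime and scope. This makes concurrent code easier to understand, maintain and debug. It addresses three key problems of traditional asynchronous programming: task leaks, error propagation and context propagation.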

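1. Core Concepts

Structured concurrency rests on the following core concepts:

  • Structured task scope (StructuredTaskScope): defines the boundary and the lifetime of a group of concurrent tasks
  • Subtasks: concurrent tasks forked inside the scope
  • Scope lifetime: by the time the scope is closed, every subtask has either completed or been cancelled
  • Error propagation: a subtask's exception is propagated to the owner of the scope (for example via ShutdownOnFailure.throwIfFailed)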
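The scope-lifetime rule can be roughly approximated with APIs that are final in Java 21. The sketch below is not StructuredTaskScope. It relies on the fact that ExecutorService.close() (Java 19+) waits for every submitted task, so no subtask can outlive the try-with-resources block. Unlike ShutdownOnFailure, it does not cancel the remaining tasks when one of them fails. The class and method names are hypothetical.

```java
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class ScopeLifetimeSketch {

    // Forks three subtasks and collects their results. The try-with-resources
    // block cannot exit before every task has finished, because close()
    // waits for the executor to terminate.
    static List<Integer> runAll() throws Exception {
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            Future<Integer> a = executor.submit(() -> slowSquare(2));
            Future<Integer> b = executor.submit(() -> slowSquare(3));
            Future<Integer> c = executor.submit(() -> slowSquare(4));
            // get() blocks until each individual result is available
            return List.of(a.get(), b.get(), c.get());
        } // close() runs here: no subtask outlives this block
    }

    // Stand-in for a blocking I/O call (hypothetical)
    private static int slowSquare(int x) throws InterruptedException {
        Thread.sleep(20);
        return x * x;
    }
}
```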

2. Basic Usage

java

import java.util.concurrent.ExecutionException;
import java.util.concurrent.StructuredTaskScope;
import java.util.concurrent.StructuredTaskScope.ShutdownOnFailure;
import java.util.concurrent.StructuredTaskScope.ShutdownOnSuccess;
import java.util.concurrent.StructuredTaskScope.Subtask;

// StructuredTaskScope is a preview API in Java 21: compile and run with --enable-preview
public class StructuredConcurrencyExample {

    // ShutdownOnFailure policy: shut the scope down as soon as any subtask fails
    public static Result performTasks() throws InterruptedException {
        try (var scope = new ShutdownOnFailure()) {
            // Fork the subtasks; in Java 21, fork() returns a Subtask, not a Future
            Subtask<DataPart1> part1 = scope.fork(() -> fetchDataPart1());
            Subtask<DataPart2> part2 = scope.fork(() -> fetchDataPart2());

            // Wait until all subtasks complete or one of them fails
            scope.join();

            // If a subtask failed, rethrow its exception wrapped in a RuntimeException
            scope.throwIfFailed(e -> new RuntimeException("Task failed", e));

            // All subtasks succeeded: combine their results
            return new Result(part1.get(), part2.get());
        }
    }

    // ShutdownOnSuccess policy: shut the scope down as soon as any subtask succeeds
    public static Response findFirstAvailable() throws InterruptedException {
        try (var scope = new ShutdownOnSuccess<Response>()) {
            // Send the same request to several services
            scope.fork(() -> callService1());
            scope.fork(() -> callService2());
            scope.fork(() -> callService3());

            // Wait until one subtask succeeds or all of them fail
            scope.join();

            // Return the first successful result; result() throws
            // ExecutionException if every subtask failed
            return scope.result();
        } catch (ExecutionException e) {
            // All services failed
            return fallbackResponse();
        }
    }

    // Plain StructuredTaskScope with no shutdown policy: join() waits for every
    // subtask, and the caller inspects each outcome itself. (A truly custom policy
    // would subclass StructuredTaskScope and override handleComplete.)
    public static void customScopeStrategy() throws InterruptedException {
        try (var scope = new StructuredTaskScope<Object>()) {
            // Fork several subtasks
            Subtask<String> subtask1 = scope.fork(() -> task1());
            Subtask<Integer> subtask2 = scope.fork(() -> task2());
            Subtask<Boolean> subtask3 = scope.fork(() -> task3());

            // Wait for all subtasks to complete (successfully or not)
            scope.join();

            // Handle each subtask's outcome manually
            if (subtask1.state() == Subtask.State.SUCCESS) {
                System.out.println("Task 1 result: " + subtask1.get());
            } else if (subtask1.state() == Subtask.State.FAILED) {
                System.out.println("Task 1 failed: " + subtask1.exception());
            }

            // Handle the other subtasks the same way...
        }
    }

    // Stub methods
    private static DataPart1 fetchDataPart1() { /* implementation omitted */ return new DataPart1(); }
    private static DataPart2 fetchDataPart2() { /* implementation omitted */ return new DataPart2(); }
    private static Response callService1() { /* implementation omitted */ return new Response(); }
    private static Response callService2() { /* implementation omitted */ return new Response(); }
    private static Response callService3() { /* implementation omitted */ return new Response(); }
    private static Response fallbackResponse() { /* implementation omitted */ return new Response(); }
    private static String task1() { /* implementation omitted */ return "result"; }
    private static Integer task2() { /* implementation omitted */ return 42; }
    private static Boolean task3() { /* implementation omitted */ return true; }

    // Data classes
    static class DataPart1 {}
    static class DataPart2 {}
    static class Result {
        public Result(DataPart1 part1, DataPart2 part2) {}
    }
    static class Response {}
}
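For comparison, the ShutdownOnSuccess pattern has a close relative among final APIs. ExecutorService.invokeAny returns the result of the first task that completes successfully and cancels the rest. Here is a minimal sketch. The three simulated services are hypothetical stand-ins for callService1..3: one fails immediately, one is slow and one is fast.

```java
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class FirstSuccessSketch {

    // Returns the result of the first task that succeeds; tasks still running
    // at that point are cancelled. Throws ExecutionException if all tasks fail.
    static String firstAvailable() throws Exception {
        List<Callable<String>> calls = List.of(
                () -> { throw new IllegalStateException("service 1 down"); },
                () -> { Thread.sleep(500); return "service 2"; },
                () -> { Thread.sleep(10); return "service 3"; });
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            return executor.invokeAny(calls);
        }
    }
}
```

The difference from StructuredTaskScope is that nothing ties these tasks to the caller's code block except the executor's own lifetime. Subtasks also cannot be inspected individually.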

3. Practical Example: Parallel API Calls and Result Aggregation

java

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.concurrent.StructuredTaskScope;
import java.util.concurrent.StructuredTaskScope.Subtask;
import java.util.concurrent.TimeoutException;

// Requires --enable-preview on Java 21 (StructuredTaskScope is a preview API)
public class ParallelApiCalls {

    private final HttpClient client = HttpClient.newBuilder()
            .connectTimeout(Duration.ofSeconds(5))
            .build();

    // Fetch profile, order history and recommendations in parallel; calls still running after 3 seconds are cancelled
    public UserProfileView getUserProfileView(String userId) throws InterruptedException {
        try (var scope = new StructuredTaskScope.ShutdownOnFailure()) {
            // Start the API calls in parallel
            Subtask<UserProfile> profileTask = scope.fork(() -> fetchUserProfile(userId));
            Subtask<OrderHistory> ordersTask = scope.fork(() -> fetchOrderHistory(userId));
            Subtask<Recommendations> recommendationsTask = scope.fork(() -> fetchRecommendations(userId));

            // Wait until all calls finish or one fails, for at most 3 seconds
            try {
                scope.joinUntil(Instant.now().plusSeconds(3));
            } catch (TimeoutException e) {
                // Deadline reached: shut the scope down, which interrupts the unfinished
                // calls, then join again so the subtasks' states may be inspected
                System.out.println("Request timed out, returning partial data");
                scope.shutdown();
                scope.join();
            }

            // Check whether any call failed
            try {
                scope.throwIfFailed(e -> new ServiceException("Failed to fetch user data", e));
            } catch (ServiceException e) {
                // Log the error but still try to return partial data
                System.err.println("Error fetching some user data: " + e.getMessage());
            }

            // Build the result, substituting defaults for any missing part
            UserProfile profile = (profileTask.state() == Subtask.State.SUCCESS)
                ? profileTask.get() : new UserProfile(userId, "Unknown", null);

            OrderHistory orders = (ordersTask.state() == Subtask.State.SUCCESS)
                ? ordersTask.get() : new OrderHistory(userId, Collections.emptyList());

            Recommendations recommendations = (recommendationsTask.state() == Subtask.State.SUCCESS)
                ? recommendationsTask.get() : new Recommendations(Collections.emptyList());

            return new UserProfileView(profile, orders, recommendations);
        }
    }
    
    // API call implementations
    private UserProfile fetchUserProfile(String userId) throws Exception {
        HttpRequest request = HttpRequest.newBuilder()
                .uri(new URI("https://api.example.com/users/" + userId))
                .GET()
                .build();
        
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        
        if (response.statusCode() != 200) {
            throw new ServiceException("Failed to fetch profile, status: " + response.statusCode());
        }
        
        // Parse the JSON response (a real implementation would use a JSON library)
        return parseUserProfile(response.body());
    }
    
    private OrderHistory fetchOrderHistory(String userId) throws Exception {
        HttpRequest request = HttpRequest.newBuilder()
                .uri(new URI("https://api.example.com/users/" + userId + "/orders"))
                .GET()
                .build();
        
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        
        if (response.statusCode() != 200) {
            throw new ServiceException("Failed to fetch orders, status: " + response.statusCode());
        }
        
        return parseOrderHistory(response.body());
    }
    
    private Recommendations fetchRecommendations(String userId) throws Exception {
        HttpRequest request = HttpRequest.newBuilder()
                .uri(new URI("https://api.example.com/users/" + userId + "/recommendations"))
                .GET()
                .build();
        
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        
        if (response.statusCode() != 200) {
            throw new ServiceException("Failed to fetch recommendations, status: " + response.statusCode());
        }
        
        return parseRecommendations(response.body());
    }
    
    // Parsing helpers and data classes (simplified)
    private UserProfile parseUserProfile(String json) { /* implementation omitted */ return new UserProfile("1", "John Doe", "john@example.com"); }
    private OrderHistory parseOrderHistory(String json) { /* implementation omitted */ return new OrderHistory("1", java.util.List.of(new Order("123", 99.99))); }
    private Recommendations parseRecommendations(String json) { /* implementation omitted */ return new Recommendations(java.util.List.of("prod1", "prod2")); }
    
    // Data classes
    static class UserProfile {
        private final String id;
        private final String name;
        private final String email;
        
        public UserProfile(String id, String name, String email) {
            this.id = id;
            this.name = name;
            this.email = email;
        }
        
        // Getters omitted
    }
    
    static class Order {
        private final String orderId;
        private final double amount;
        
        public Order(String orderId, double amount) {
            this.orderId = orderId;
            this.amount = amount;
        }
        
        // Getters omitted
    }
    
    static class OrderHistory {
        private final String userId;
        private final java.util.List<Order> orders;
        
        public OrderHistory(String userId, java.util.List<Order> orders) {
            this.userId = userId;
            this.orders = orders;
        }
        
        // Getters omitted
    }
    
    static class Recommendations {
        private final java.util.List<String> productIds;
        
        public Recommendations(java.util.List<String> productIds) {
            this.productIds = productIds;
        }
        
        // Getters omitted
    }
    
    static class UserProfileView {
        private final UserProfile profile;
        private final OrderHistory orderHistory;
        private final Recommendations recommendations;
        
        public UserProfileView(UserProfile profile, OrderHistory orderHistory, Recommendations recommendations) {
            this.profile = profile;
            this.orderHistory = orderHistory;
            this.recommendations = recommendations;
        }
        
        // Getters omitted
    }
    
    static class ServiceException extends Exception {
        public ServiceException(String message) {
            super(message);
        }
        
        public ServiceException(String message, Throwable cause) {
            super(message, cause);
        }
    }
}
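The deadline-plus-fallback logic above can also be sketched without preview APIs, provided a plain ExecutorService is acceptable. The timed invokeAll(tasks, timeout, unit) cancels every task that has not finished by the deadline. Future.state() (Java 19+) then separates successful results from failed or cancelled ones. The three calls and the valueOr helper are hypothetical stand-ins for the profile, orders and recommendations requests.

```java
import java.time.Duration;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

class PartialResultsSketch {

    // Uses the task's result only if it succeeded; FAILED, CANCELLED (and,
    // defensively, RUNNING) all fall back to the default value.
    static <T> T valueOr(Future<T> future, T fallback) {
        return future.state() == Future.State.SUCCESS ? future.resultNow() : fallback;
    }

    // Runs all calls under one overall deadline; invokeAll cancels whatever
    // is still unfinished when the deadline passes.
    static List<String> fetchWithDeadline(Duration deadline) throws InterruptedException {
        List<Callable<String>> calls = List.of(
                () -> "profile",                                            // fast
                () -> { Thread.sleep(5_000); return "orders"; },            // too slow
                () -> { throw new IllegalStateException("recs failed"); }); // fails
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            List<Future<String>> futures =
                    executor.invokeAll(calls, deadline.toMillis(), TimeUnit.MILLISECONDS);
            return List.of(
                    valueOr(futures.get(0), "unknown profile"),
                    valueOr(futures.get(1), "no orders"),
                    valueOr(futures.get(2), "no recommendations"));
        }
    }
}
```

With a one-second deadline, this returns the real profile together with the two fallbacks.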

4. Structured Concurrency vs. CompletableFuture

java

// Traditional CompletableFuture approach, written as another method of ParallelApiCalls
// (it additionally needs: import java.util.Collections; import java.util.concurrent.CompletableFuture;
//  import java.util.concurrent.CompletionException; import java.util.concurrent.Future;
//  import java.util.concurrent.TimeUnit;)
public UserProfileView getUserProfileViewWithCompletableFuture(String userId) {
    // supplyAsync takes a Supplier, so checked exceptions must be wrapped by hand.
    // Without an explicit executor, these blocking calls run on the common ForkJoinPool.
    CompletableFuture<UserProfile> profileFuture = CompletableFuture.supplyAsync(() -> {
        try {
            return fetchUserProfile(userId);
        } catch (Exception e) {
            throw new CompletionException(e);
        }
    });
    
    CompletableFuture<OrderHistory> ordersFuture = CompletableFuture.supplyAsync(() -> {
        try {
            return fetchOrderHistory(userId);
        } catch (Exception e) {
            throw new CompletionException(e);
        }
    });
    
    CompletableFuture<Recommendations> recommendationsFuture = CompletableFuture.supplyAsync(() -> {
        try {
            return fetchRecommendations(userId);
        } catch (Exception e) {
            throw new CompletionException(e);
        }
    });
    
    // Combine the results
    CompletableFuture<UserProfileView> combined = profileFuture
        .thenCombine(ordersFuture, (profile, orders) -> new Pair<>(profile, orders))
        .thenCombine(recommendationsFuture, (pair, recommendations) -> 
            new UserProfileView(pair.first, pair.second, recommendations));
    
    // Add a timeout. orTimeout only completes `combined`; the three source tasks
    // are not cancelled and keep running in the background (a task leak)
    return combined.orTimeout(3, TimeUnit.SECONDS)
        .exceptionally(e -> {
            // Error handling and partial-result assembly. Only use futures that have
            // already succeeded: calling join() on one that is still running would block
            UserProfile profile = profileFuture.state() == Future.State.SUCCESS
                ? profileFuture.resultNow() : new UserProfile(userId, "Unknown", null);
            
            OrderHistory orders = ordersFuture.state() == Future.State.SUCCESS
                ? ordersFuture.resultNow() : new OrderHistory(userId, Collections.emptyList());
            
            Recommendations recommendations = recommendationsFuture.state() == Future.State.SUCCESS
                ? recommendationsFuture.resultNow() : new Recommendations(Collections.emptyList());
            
            return new UserProfileView(profile, orders, recommendations);
        })
        .join();
}

// Helper class (also nested inside ParallelApiCalls)
static class Pair<A, B> {
    final A first;
    final B second;
    
    Pair(A first, B second) {
        this.first = first;
        this.second = second;
    }
}
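One difference is worth seeing concretely. orTimeout only completes the future it is called on, and returns that same future; the tasks it depends on are not cancelled. The hypothetical sketch below uses a latch to hold a task open. It then checks that the task is still running after the timeout on a dependent future has fired.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

class TimeoutLeakSketch {

    // Returns true if the dependent future fell back after the timeout while
    // the underlying task was still running (i.e. it was not cancelled).
    static boolean sourceStillRunningAfterTimeout() throws InterruptedException {
        CountDownLatch release = new CountDownLatch(1);
        CompletableFuture<String> source = CompletableFuture.supplyAsync(() -> {
            try {
                release.await(); // blocks until released: simulates a slow call
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            return "late result";
        });

        CompletableFuture<String> withTimeout = source.thenApply(s -> s)
                .orTimeout(100, TimeUnit.MILLISECONDS)  // times out the dependent only
                .exceptionally(e -> "fallback");

        String value = withTimeout.join();       // completes after roughly 100 ms
        boolean stillRunning = !source.isDone(); // the source task was not cancelled
        release.countDown();                     // let the task finish
        source.join();
        return value.equals("fallback") && stillRunning;
    }
}
```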

5. Advantages of Structured Concurrency

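  1. Lifecycle management: subtask lifetimes are bounded by the scope, which prevents task leaks
  2. Error propagation: with a policy such as ShutdownOnFailure, a subtask's failure is surfaced to the scope owner (throwIfFailed) and the remaining subtasks are cancelled
  3. Cancellation propagation: shutting down the scope (which close() also does) interrupts all unfinished subtasks
  4. Timeouts: joinUntil(Instant) applies a single deadline to the whole scope
  5. Code structure: the block structure of the code mirrors the execution flow, so it is easier to read and maintain
  6. Resource management: try-with-resources guarantees that the scope is closed and its threads are cleaned up
  7. Composability: scopes can be nested, which supports complex concurrency patterns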

II. Virtual Threads for Millions of Concurrent Connections

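Virtual threads are lightweight threads, finalized in Java 21 (JEP 444). The JVM manages them and schedules them onto a small pool of platform threads called carrier threads. When a virtual thread blocks on I/O or sleeps, it is unmounted from its carrier, and the carrier is then free to run other virtual threads. This lets an application hold a very large number of concurrent connections using a simple thread-per-connection style. In Java 21, blocking inside a synchronized block or a native method pins the carrier thread instead.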
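Here is a minimal sketch of starting virtual threads directly with the Java 21 Thread.ofVirtual() builder. Thread.startVirtualThread and Executors.newVirtualThreadPerTaskExecutor() are equivalent entry points. The thread count and the 10 ms sleep are illustrative values, and the class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

class VirtualThreadBasics {

    // Starts `count` virtual threads that each block briefly, waits for all of
    // them, and returns how many actually ran on a virtual thread.
    static int runVirtual(int count) throws InterruptedException {
        AtomicInteger virtualCount = new AtomicInteger();
        List<Thread> threads = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            Thread t = Thread.ofVirtual().start(() -> {
                if (Thread.currentThread().isVirtual()) {
                    virtualCount.incrementAndGet();
                }
                try {
                    Thread.sleep(10); // blocking call: the virtual thread unmounts from its carrier
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            threads.add(t);
        }
        for (Thread t : threads) {
            t.join();
        }
        return virtualCount.get();
    }
}
```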

1. A High-Concurrency HTTP Server

java

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class MillionConnectionsServer {
    
    private static final int PORT = 8080;
    private static final AtomicInteger activeConnections = new AtomicInteger(0);
    private static final AtomicInteger totalConnections = new AtomicInteger(0);
    
    public static void main(String[] args) throws IOException {
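        // Open the server socket channel
        try (ServerSocketChannel serverChannel = ServerSocketChannel.open()) {
            serverChannel.bind(new InetSocketAddress(PORT));
            System.out.println("Server started on port " + PORT);
            
            // Create an executor that starts a new virtual thread for every task
            try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
                // Start the monitoring task
                executor.submit(MillionConnectionsServer::monitorConnections);
                
                // Note: keeping very many sockets open also requires raising OS limits,
                // such as the maximum number of open file descriptors
                while (true) {
                    // Accept a new connection (blocks until a client connects)
                    SocketChannel clientChannel = serverChannel.accept();
                    
                    // Handle each connection on its own virtual thread
                    executor.submit(() -> handleConnection(clientChannel));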
                }
            }
        }
    }
    
    private static void handleConnection(SocketChannel clientChannel) {
        int connectionId = totalConnections.incrementAndGet();
        int current = activeConnections.incrementAndGet();
        
        try {
            System.out.printf("Connection #%d accepted. Active: %d%n", connectionId, current);
            
            // Accepted channels are already in blocking mode. Blocking I/O is what we want
            // here, because a blocked virtual thread releases its carrier thread
            clientChannel.configureBlocking(true);
            
            // Read the request (a single read; a real server would parse the full HTTP request)
            ByteBuffer buffer = ByteBuffer.allocate(1024);
            int bytesRead = clientChannel.read(buffer);
            
            if (bytesRead > 0) {
                buffer.flip();
                String request = StandardCharsets.UTF_8.decode(buffer).toString();
                System.out.printf("Connection #%d received: %s%n", connectionId, request.trim());
                
                // Simulate processing time (e.g., a database query or other I/O)
                Thread.sleep(Duration.ofMillis(100));
                
                // Build the response
                String response = "HTTP/1.1 200 OK\r\n" +
                                 "Content-Type: text/plain\r\n" +
                                 "Connection: close\r\n" +
                                 "\r\n" +
                                 "Hello from virtual thread #" + connectionId + "\r\n";
                
                ByteBuffer responseBuffer = ByteBuffer.wrap(response.getBytes(StandardCharsets.UTF_8));
                
                // Send the response
                while (responseBuffer.hasRemaining()) {
                    clientChannel.write(responseBuffer);
                }
            }
        } catch (Exception e) {
            System.err.printf("Error handling connection #%d: %s%n", connectionId, e.getMessage());
        } finally {
            try {
                clientChannel.close();
            } catch (IOException e) {
                System.err.printf("Error closing connection #%d: %s%n", connectionId, e.getMessage());
            }
            activeConnections.decrementAndGet();
        }
    }
    
    private static void monitorConnections() {
        try {
            while (true) {
                Thread.sleep(Duration.ofSeconds(5));
                int active = activeConnections.get();
                int total = totalConnections.get();
                System.out.printf("=== STATS: Active connections: %d, Total handled: %d ===%n", active, total);
                
                // Print JVM heap usage
                Runtime rt = Runtime.getRuntime();
                long usedMemory = (rt.totalMemory() - rt.freeMemory()) / 1024 / 1024;
                long totalMemory = rt.totalMemory() / 1024 / 1024;
                System.out.printf("=== MEMORY: Used: %d MB, Total: %d MB ===%n", usedMemory, totalMemory);
            }
        } catch (InterruptedException e) {
            System.out.println("Monitoring thread interrupted");
        }
    }
}
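If a hand-written socket loop is not needed, the JDK's built-in com.sun.net.httpserver.HttpServer (module jdk.httpserver) can be given a virtual-thread-per-task executor through setExecutor. Every exchange then runs on its own virtual thread. This is a swapped-in alternative to the NIO server above, not the article's implementation. The ephemeral port, the /hello path and the response text are assumptions.

```java
import com.sun.net.httpserver.HttpServer;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class VirtualThreadHttpServer {

    // Starts a server on an ephemeral port whose handlers run on virtual threads,
    // sends it one request, shuts it down and returns the response body.
    static String roundTrip() throws IOException, InterruptedException {
        HttpServer server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0);
        ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
        server.setExecutor(executor); // each exchange is handled on a new virtual thread
        server.createContext("/hello", exchange -> {
            byte[] body = ("virtual=" + Thread.currentThread().isVirtual())
                    .getBytes(StandardCharsets.UTF_8);
            exchange.sendResponseHeaders(200, body.length);
            try (OutputStream out = exchange.getResponseBody()) {
                out.write(body); // closing the stream completes the exchange
            }
        });
        server.start();
        try {
            int port = server.getAddress().getPort();
            HttpClient client = HttpClient.newHttpClient();
            HttpRequest request = HttpRequest.newBuilder(
                    URI.create("http://127.0.0.1:" + port + "/hello")).build();
            return client.send(request, HttpResponse.BodyHandlers.ofString()).body();
        } finally {
            server.stop(0);
            executor.close();
        }
    }
}
```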

2. Handling a Million WebSocket Connections

java

import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.java_websocket.WebSocket;
import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;

public class MillionWebSocketServer extends WebSocketServer {
    
    private static final int PORT = 8080;
    private static final AtomicInteger connectedClients = new AtomicInteger(0);
    private static final AtomicInteger totalMessages = new AtomicInteger(0);
    private final ConcurrentHashMap<WebSocket, ConnectionInfo> connections = new ConcurrentHashMap<>();
    
    public static void main(String[] args) {
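        // Note: Java-WebSocket handles network I/O on its own selector and worker threads.
        // Setting the common ForkJoinPool's thread factory does not affect them (and virtual
        // threads cannot act as ForkJoinPool workers), so virtual threads are used below
        // only to offload per-connection and per-message work.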
        
        // Start the WebSocket server
        MillionWebSocketServer server = new MillionWebSocketServer(new InetSocketAddress(PORT));
        server.setReuseAddr(true);
        server.start();
        
        // Start a scheduled monitoring task (on a platform thread)
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        scheduler.scheduleAtFixedRate(() -> {
            int clients = connectedClients.get();
            int messages = totalMessages.get();
            System.out.printf("Connected clients: %d, Total messages: %d%n", clients, messages);
            
            // Print memory usage
            Runtime rt = Runtime.getRuntime();
            long usedMemory = (rt.totalMemory() - rt.freeMemory()) / 1024 / 1024;
            long totalMemory = rt.totalMemory() / 1024 / 1024;
            System.out.printf("Memory usage: %d MB / %d MB%n", usedMemory, totalMemory);
        }, 5, 5, TimeUnit.SECONDS);
        
        // Broadcast a message periodically (simulating real-time updates)
        scheduler.scheduleAtFixedRate(() -> {
            server.broadcast("Server time: " + java.time.LocalDateTime.now());
        }, 10, 10, TimeUnit.SECONDS);
    }
    
    public MillionWebSocketServer(InetSocketAddress address) {
        super(address);
    }
    
    @Override
    public void onOpen(WebSocket conn, ClientHandshake handshake) {
        int clientId = connectedClients.incrementAndGet();
        connections.put(conn, new ConnectionInfo(clientId, System.currentTimeMillis()));
        
        System.out.printf("New connection #%d established from %s%n", 
                clientId, conn.getRemoteSocketAddress());
        
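        // Send a welcome message
        conn.send("Welcome! You are client #" + clientId);
        
        // Offload connection initialization (which may involve database queries, etc.) to a virtual thread
        Thread.startVirtualThread(() -> {
            try {
                // Simulate initialization work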
                Thread.sleep(100);
                conn.send("Connection initialized. Start sending messages!");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
    }
    
    @Override
    public void onClose(WebSocket conn, int code, String reason, boolean remote) {
        ConnectionInfo info = connections.remove(conn);
        if (info != null) {
            connectedClients.decrementAndGet();
            System.out.printf("Connection #%d closed. Code: %d, Reason: %s, Remote: %b%n", 
                    info.id, code, reason, remote);
        }
    }
    
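    @Override
    public void onMessage(WebSocket conn, String message) {
        ConnectionInfo info = connections.get(conn);
        if (info != null) {
            totalMessages.incrementAndGet();
            // Capture this message's sequence number on the callback thread (callbacks for
            // one connection are assumed to arrive sequentially), so the virtual thread
            // below does not read a counter that later messages keep changing
            int messageNo = ++info.messageCount;
            
            // Process the message on a virtual thread
            Thread.startVirtualThread(() -> processMessage(conn, info, messageNo, message));
        }
    }
    
    private void processMessage(WebSocket conn, ConnectionInfo info, int messageNo, String message) {
        try {
            // Simulate message processing (database operations, external API calls, etc.)
            Thread.sleep(50);
            
            // Reply to the client
            conn.send("Received message #" + messageNo + ": " + message);
            
            // For broadcast messages, forward the text to every other client
            if (message.startsWith("BROADCAST:")) {
                String broadcastMessage = "User #" + info.id + " says: "
                        + message.substring("BROADCAST:".length());
                // broadcast(String, Collection<WebSocket>) sends only to the given clients
                broadcast(broadcastMessage,
                        connections.keySet().stream().filter(c -> c != conn).toList());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }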
    
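    @Override
    public void onMessage(WebSocket conn, ByteBuffer message) {
        // Handle a binary message
        ConnectionInfo info = connections.get(conn);
        if (info != null) {
            totalMessages.incrementAndGet();
            info.messageCount++;
            // Read the size before handing off: the buffer is not guaranteed to stay
            // unchanged after this callback returns
            int size = message.remaining();
            
            // Process the binary message on a virtual thread
            Thread.startVirtualThread(() -> {
                try {
                    // Simulate processing time
                    Thread.sleep(50);
                    
                    // Send an acknowledgement
                    String ack = "Received binary message of size " + size + " bytes";
                    conn.send(ByteBuffer.wrap(ack.getBytes(java.nio.charset.StandardCharsets.UTF_8)));
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
        }
    }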
    
    @Override
    public void onError(WebSocket conn, Exception ex) {
        ConnectionInfo info = conn != null ? connections.get(conn) : null;
        System.err.printf("Error on connection %s: %s%n", 
                (info != null ? "#" + info.id : "unknown"), ex.getMessage());
        ex.printStackTrace();
    }
    
    @Override
    public void onStart() {
        System.out.println("WebSocket server started on port " + PORT);
        System.out.println("Using virtual threads for connection handling");
        setConnectionLostTimeout(60); // 60 seconds timeout
    }
    
    // Per-connection state
    private static class ConnectionInfo {
        final int id;
        final long connectTime;
        int messageCount = 0;
        
        ConnectionInfo(int id, long connectTime) {
            this.id = id;
            this.connectTime = connectTime;
        }
    }
}
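When onMessage starts a new virtual thread for every message, replies for one connection can go out of order if processing times differ. If ordering matters, one option is a per-connection single-threaded executor backed by a virtual thread. This is an assumption, not part of the original design. The sketch below covers a single connection, and the class name and "ack:" format are hypothetical.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class OrderedPerConnection {

    // Processes one connection's messages in arrival order on a single virtual
    // thread; each connection would get its own executor of this kind.
    static List<String> processInOrder(List<String> messages) {
        List<String> replies = new CopyOnWriteArrayList<>();
        try (ExecutorService perConnection =
                     Executors.newSingleThreadExecutor(Thread.ofVirtual().factory())) {
            for (String message : messages) {
                perConnection.submit(() -> {
                    try {
                        Thread.sleep(message.length()); // uneven simulated processing time
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    }
                    replies.add("ack:" + message);
                });
            }
        } // close() waits until the queued messages have been processed
        return replies;
    }
}
```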

3. Integrating a Database Connection Pool with Virtual Threads

java

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;

public class VirtualThreadDbConnectionPool {
    
    private static final int CONCURRENT_REQUESTS = 10_000;
    private static final AtomicInteger completedRequests = new AtomicInteger(0);
    private static final AtomicInteger activeThreads = new AtomicInteger(0);
    
    public static void main(String[] args) throws Exception {
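        // Configure the HikariCP connection pool
        HikariConfig config = new HikariConfig();
        config.setJdbcUrl("jdbc:postgresql://localhost:5432/testdb");
        config.setUsername("postgres");
        config.setPassword("password");
        config.setMaximumPoolSize(20); // only 20 physical database connections
        config.setMinimumIdle(5);
        // Each request holds a connection for roughly 0.1-0.15 s, so 10,000 requests over
        // 20 connections take about a minute in total. With a 30 s wait limit, requests
        // at the back of the queue would time out, so the limit is raised (illustrative value)
        config.setConnectionTimeout(120_000);
        config.setIdleTimeout(600000);
        config.setMaxLifetime(1800000);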
        
        try (HikariDataSource dataSource = new HikariDataSource(config)) {
            System.out.println("Database connection pool initialized");
            
            // Note: in Java 21, a virtual thread that blocks inside a synchronized block pins
            // its carrier thread; check that the JDBC driver and pool versions in use avoid this.

            // Start the progress monitor on its own virtual thread
            Thread monitorThread = Thread.startVirtualThread(() -> monitorProgress());
            System.out.println("Starting " + CONCURRENT_REQUESTS + " concurrent database operations");

            // Start one virtual thread per request; at most 20 of them hold a database
            // connection at any time, and the rest wait inside HikariCP
            List<Thread> threads = new ArrayList<>();
            for (int i = 0; i < CONCURRENT_REQUESTS; i++) {
                final int requestId = i;
                Thread t = Thread.ofVirtual().start(() -> {
                    activeThreads.incrementAndGet();
                    try {
                        performDatabaseOperation(dataSource, requestId);
                    } finally {
                        activeThreads.decrementAndGet();
                    }
                });
                threads.add(t);
            }

            // Wait for all requests to finish
            for (Thread t : threads) {
                t.join();
            }

            System.out.println("All database operations completed");
            monitorThread.interrupt();
        }
    }
    
    private static void performDatabaseOperation(HikariDataSource dataSource, int requestId) {
        try {
            // Simulate a random amount of preparatory work
            Thread.sleep(Duration.ofMillis((long) (Math.random() * 100)));
            
            // Borrow a connection (waits while all 20 pooled connections are in use)
            try (Connection conn = dataSource.getConnection()) {
                // Run the query (pg_sleep(0.1) simulates a 100 ms query)
                try (PreparedStatement stmt = conn.prepareStatement(
                        "SELECT pg_sleep(0.1), ?::int as request_id")) {
                    stmt.setInt(1, requestId);
                    
                    try (ResultSet rs = stmt.executeQuery()) {
                        if (rs.next()) {
                            int result = rs.getInt("request_id");
                            // Simulate processing the result
                            Thread.sleep(Duration.ofMillis((long) (Math.random() * 50)));
                        }
                    }
                }
            }
            
            // Request finished
            int completed = completedRequests.incrementAndGet();
            if (completed % 100 == 0) {
                System.out.printf("Completed %d/%d requests%n", completed, CONCURRENT_REQUESTS);
            }
        } catch (SQLException e) {
            System.err.println("Database error in request #" + requestId + ": " + e.getMessage());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            System.err.println("Request #" + requestId + " interrupted");
        }
    }
    
    private static void monitorProgress() {
        try {
            while (!Thread.currentThread().isInterrupted()) {
                Thread.sleep(Duration.ofSeconds(1));
                int active = activeThreads.get();
                int completed = completedRequests.get();
                
                // Read JVM heap usage
                Runtime rt = Runtime.getRuntime();
                long usedMemory = (rt.totalMemory() - rt.freeMemory()) / 1024 / 1024;
                
                System.out.printf("Progress: %d/%d completed, %d active, Memory: %d MB%n", 
                        completed, CONCURRENT_REQUESTS, active, usedMemory);
            }
        } catch (InterruptedException e) {
            // Expected interruption: exit normally
        }
    }
}
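With virtual threads, the common advice is not to pool the threads. Instead, limit access to the scarce resource itself. The hypothetical sketch below uses a Semaphore sized like the connection pool (20 in the example above), so at most that many virtual threads are inside the guarded section at once. The sleep stands in for a query on a pooled connection.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;

class BoundedResourceAccess {

    // Runs `tasks` virtual threads but lets at most `permits` of them use the
    // scarce resource concurrently; returns the highest concurrency observed.
    static int maxConcurrentUse(int tasks, int permits) {
        Semaphore semaphore = new Semaphore(permits);
        AtomicInteger inUse = new AtomicInteger();
        AtomicInteger maxInUse = new AtomicInteger();
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            for (int i = 0; i < tasks; i++) {
                executor.submit(() -> {
                    semaphore.acquireUninterruptibly(); // waiting virtual threads park here
                    try {
                        int now = inUse.incrementAndGet();
                        maxInUse.accumulateAndGet(now, Math::max);
                        Thread.sleep(5); // stand-in for a query
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    } finally {
                        inUse.decrementAndGet();
                        semaphore.release();
                    }
                });
            }
        } // close() waits for every task
        return maxInUse.get();
    }
}
```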

4. Performance Comparison: Virtual Threads vs. Platform Threads

java

import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

public class ThreadPerformanceComparison {
    
    private static final int TASK_COUNT = 100_000;
    private static final int IO_SIMULATION_MS = 50; // simulated blocking I/O time per task (ms)
    
    public static void main(String[] args) throws Exception {
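        // A rough single-run comparison without JIT warm-up; use a harness such as JMH
        // for reliable numbers.
        // Platform threads: fixed pool of 200
        System.out.println("Testing with platform threads...");
        testWithExecutor(Executors.newFixedThreadPool(200), "Platform Thread Pool (200 threads)");
        
        // Give the GC and the system a moment to settle
        System.gc();
        Thread.sleep(2000);
        
        // Virtual threads: one per task
        System.out.println("\nTesting with virtual threads...");
        testWithExecutor(Executors.newVirtualThreadPerTaskExecutor(), "Virtual Thread Per Task");
        
        // Virtual threads created directly from a thread factory
        System.out.println("\nTesting with virtual thread factory...");
        testWithThreadFactory(Thread.ofVirtual().factory(), "Virtual Thread Factory");
        
        // Platform threads created directly from a thread factory (for comparison).
        // Warning: this starts up to 100,000 OS threads and may fail with
        // "OutOfMemoryError: unable to create native thread", depending on OS limits
        System.out.println("\nTesting with platform thread factory...");
        testWithThreadFactory(Thread.ofPlatform().factory(), "Platform Thread Factory");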
    }
    
    private static void testWithExecutor(ExecutorService executor, String testName) throws Exception {
        Instant start = Instant.now();
        CountDownLatch latch = new CountDownLatch(TASK_COUNT);
        AtomicInteger activeTasks = new AtomicInteger(0);
        AtomicInteger maxActiveTasks = new AtomicInteger(0);
        
        // Submit the tasks
        for (int i = 0; i < TASK_COUNT; i++) {
            executor.submit(() -> {
                try {
                    // Track the number of concurrently active tasks
                    int active = activeTasks.incrementAndGet();
                    maxActiveTasks.updateAndGet(max -> Math.max(max, active));
                    
                    // Simulate blocking I/O
                    Thread.sleep(Duration.ofMillis(IO_SIMULATION_MS));
                    
                    return true;
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return false;
                } finally {
                    activeTasks.decrementAndGet();
                    latch.countDown();
                }
            });
        }
        
        // Wait for all tasks to finish
        latch.await();
        executor.shutdown();
        
        Instant end = Instant.now();
        Duration duration = Duration.between(start, end);
        
        // Print the results
        System.out.printf("%s:%n", testName);
        System.out.printf("  Completed %d tasks in %d ms%n", TASK_COUNT, duration.toMillis());
        System.out.printf("  Throughput: %.2f tasks/second%n", 
                TASK_COUNT / (duration.toMillis() / 1000.0));
        System.out.printf("  Max concurrent tasks: %d%n", maxActiveTasks.get());
        
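        // Heap usage after the run (a rough indicator only: platform thread stacks live in native memory and are not counted)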
        Runtime rt = Runtime.getRuntime();
        long usedMemory = (rt.totalMemory() - rt.freeMemory()) / 1024 / 1024;
        System.out.printf("  Memory usage: %d MB%n", usedMemory);
    }
    
    private static void testWithThreadFactory(ThreadFactory factory, String testName) throws Exception {
        Instant start = Instant.now();
        CountDownLatch latch = new CountDownLatch(TASK_COUNT);
        AtomicInteger activeTasks = new AtomicInteger(0);
        AtomicInteger maxActiveTasks = new AtomicInteger(0);
        
        // Create and start one thread per task
        for (int i = 0; i < TASK_COUNT; i++) {
            Thread thread = factory.newThread(() -> {
                try {
                    // Track the number of concurrently active tasks
                    int active = activeTasks.incrementAndGet();
                    maxActiveTasks.updateAndGet(max -> Math.max(max, active));
                    
                    // Simulate an I/O operation
                    Thread.sleep(Duration.ofMillis(IO_SIMULATION_MS));
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    activeTasks.decrementAndGet();
                    latch.countDown();
                }
            });
            thread.start();
        }
        
        // Wait for all threads to finish
        latch.await();
        
        Instant end = Instant.now();
        Duration duration = Duration.between(start, end);
        
        // Print the results
        System.out.printf("%s:%n", testName);
        System.out.printf("  Completed %d tasks in %d ms%n", TASK_COUNT, duration.toMillis());
        System.out.printf("  Throughput: %.2f tasks/second%n", 
                TASK_COUNT / (duration.toMillis() / 1000.0));
        System.out.printf("  Max concurrent tasks: %d%n", maxActiveTasks.get());
        
        // Print current heap usage (a rough snapshot that depends on GC timing)
        Runtime rt = Runtime.getRuntime();
        long usedMemory = (rt.totalMemory() - rt.freeMemory()) / 1024 / 1024;
        System.out.printf("  Memory usage: %d MB%n", usedMemory);
    }
}

III. Efficient Thread-Scoped Data Sharing with Scoped Values

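Scoped values (JEP 446) are a preview API in Java 21, so code that uses them must be compiled and run with --enable-preview. A scoped value lets a method share immutable data with everything it calls for a bounded period of execution. Subtasks forked through StructuredTaskScope also see it. This design makes scoped values a good fit for virtual threads. They avoid three weaknesses of ThreadLocal:
- any code can change a ThreadLocal at any time through set();
- its value lives until the thread ends unless someone remembers to call remove(), which causes stale data and leaks on pooled threads;
- InheritableThreadLocal copies values into every child thread, which becomes expensive with very large numbers of virtual threads.
A scoped value, by contrast, is bound once with where(...) and is unbound automatically when the bounded run() or call() finishes.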
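To make the bounded lifetime concrete, here is a minimal sketch. The class ScopedValueLifetime and the value USER are made up for illustration, and Java 21 requires --enable-preview. The sketch checks isBound() before and after a where(...).run(...) call, reads the value from a callee inside the scope, and uses orElse(...) to read safely once the binding is gone.

```java
// A minimal sketch; on Java 21 compile and run with --enable-preview.
class ScopedValueLifetime {

    // Declared once, typically as a static final field.
    static final ScopedValue<String> USER = ScopedValue.newInstance();

    // A callee reads the value without receiving it as a parameter.
    static String greet() {
        return "hello, " + USER.get();
    }

    static String demo() {
        StringBuilder log = new StringBuilder();
        log.append("before: bound=").append(USER.isBound()).append('\n');
        // The binding exists only while run() is executing.
        ScopedValue.where(USER, "alice")
                .run(() -> log.append("inside: ").append(greet()).append('\n'));
        // Once run() has returned, USER is unbound again; get() would throw NoSuchElementException.
        log.append("after: bound=").append(USER.isBound())
           .append(", orElse=").append(USER.orElse("anonymous"));
        return log.toString();
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

Running main prints three lines: "before: bound=false", "inside: hello, alice" and "after: bound=false, orElse=anonymous".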

1. Basic Usage

java

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;

public class ScopedValuesBasics {
    
    // Declare the ScopedValue keys
    private static final ScopedValue<String> USER_ID = ScopedValue.newInstance();
    private static final ScopedValue<RequestContext> REQUEST_CONTEXT = ScopedValue.newInstance();
    
    public static void main(String[] args) {
        // Use a virtual-thread-per-task executor (close() waits for all submitted tasks)
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            // Submit several tasks
            for (int i = 0; i < 10; i++) {
                final int requestId = i;
                executor.submit(() -> processRequest("user-" + requestId));
            }
        }
    }
    
    private static void processRequest(String userId) {
        // Create the request context
        RequestContext context = new RequestContext(
                "REQ-" + ThreadLocalRandom.current().nextInt(10000),
                System.currentTimeMillis());
        
        // Bind the values with ScopedValue.where(...) and run the code block
        ScopedValue.where(USER_ID, userId)
            .where(REQUEST_CONTEXT, context)
            .run(() -> {
                // Within this scope, USER_ID and REQUEST_CONTEXT are bound
                handleRequest();
                
                // Nested scope
                ScopedValue.where(USER_ID, userId + "-admin")
                    .run(() -> {
                        // In this nested scope USER_ID is rebound,
                        // while REQUEST_CONTEXT keeps its outer value
                        performAdminOperation();
                    });
                
                // Back in the outer scope, USER_ID has its original value again
                finalizeRequest();
            });
    }
    
    private static void handleRequest() {
        // Read the scoped values
        String userId = USER_ID.get();
        RequestContext context = REQUEST_CONTEXT.get();
        
        System.out.printf("Handling request for user %s, request ID: %s, timestamp: %d%n",
                userId, context.requestId(), context.timestamp());
        
        // Simulate processing
        try {
            Thread.sleep(100);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        
        // Call another method; the bindings are visible to callees without being passed as parameters
        validatePermissions();
    }
    
    private static void validatePermissions() {
        // A scoped value can be read anywhere further down the call chain
        String userId = USER_ID.get();
        System.out.printf("Validating permissions for user %s%n", userId);
        
        // Simulate the validation
        try {
            Thread.sleep(50);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
    
    private static void performAdminOperation() {
        // Read the values bound by the nested scope
        String userId = USER_ID.get();
        RequestContext context = REQUEST_CONTEXT.get();
        
        System.out.printf("Performing admin operation for %s, request ID: %s%n",
                userId, context.requestId());
        
        // Simulate the operation
        try {
            Thread.sleep(50);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
    
    private static void finalizeRequest() {
        String userId = USER_ID.get();
        RequestContext context = REQUEST_CONTEXT.get();
        
        System.out.printf("Finalizing request for user %s, request ID: %s%n",
                userId, context.requestId());
        
        // Simulate finalization
        try {
            Thread.sleep(50);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
    
    // Immutable context record
    record RequestContext(String requestId, long timestamp) {}
}
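The example above only uses run(), which returns nothing. When the scoped operation needs to produce a result, the carrier returned by where(...) also offers call(...). The sketch below assumes a hypothetical tenant lookup: TENANT_ID and configKey are illustrative names, not part of the original code, and Java 21 needs --enable-preview. In Java 21, call() accepts a Callable and is declared to throw Exception, which is why configFor declares it too.

```java
// A sketch; on Java 21 compile and run with --enable-preview.
class ScopedValueCall {

    static final ScopedValue<Integer> TENANT_ID = ScopedValue.newInstance();

    // Hypothetical lookup that depends on the tenant bound by the caller.
    static String configKey(String key) {
        return "tenant-" + TENANT_ID.get() + "/" + key;
    }

    // call() binds the value like run(), but returns the operation's result.
    static String configFor(int tenantId, String key) throws Exception {
        return ScopedValue.where(TENANT_ID, tenantId).call(() -> configKey(key));
    }

    public static void main(String[] args) throws Exception {
        System.out.println(configFor(7, "db.url")); // tenant-7/db.url
        System.out.println(TENANT_ID.isBound());    // false
    }
}
```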

2. Request Context Propagation in Web Applications

java

import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.UUID;
import java.util.concurrent.Executors;

import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

public class ScopedValuesWebFilter implements Filter {
    
    // Application-wide scoped values
    private static final ScopedValue<RequestContext> REQUEST_CONTEXT = ScopedValue.newInstance();
    private static final ScopedValue<UserInfo> CURRENT_USER = ScopedValue.newInstance();
    
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        HttpServletResponse httpResponse = (HttpServletResponse) response;
        
        // Create the request context
        String requestId = UUID.randomUUID().toString();
        Instant startTime = Instant.now();
        RequestContext context = new RequestContext(
                requestId,
                startTime,
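                httpRequest.getMethod(),
                httpRequest.getRequestURI(),
                httpRequest.getRemoteAddr());
        
        // Extract the user info (from an auth token, the session, etc.)
        UserInfo userInfo = extractUserInfo(httpRequest);
        
        // Bind the context with ScopedValue and process the request; the bindings
        // are visible to all code that runs on this thread inside run()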
        ScopedValue.where(REQUEST_CONTEXT, context)
            .where(CURRENT_USER, userInfo)
            .run(() -> {
                try {
                    // Add the request ID to the response headers
                    httpResponse.setHeader("X-Request-ID", requestId);
                    
                    // Continue down the filter chain
                    chain.doFilter(request, response);
                } catch (Exception e) {
                    // Log the exception
                    logException(e);
                    
                    try {
                        httpResponse.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, 
                                "Internal server error");
                    } catch (IOException ex) {
                        // Ignore: nothing more can be sent to the client
                    }
                } finally {
                    // Log request completion
                    Instant endTime = Instant.now();
                    Duration duration = Duration.between(startTime, endTime);
                    logRequestCompletion(httpResponse.getStatus(), duration);
                }
            });
    }
    
    // Extract user info from the request
    private UserInfo extractUserInfo(HttpServletRequest request) {
        // A real implementation would read it from an auth token, the session, etc.
        String authHeader = request.getHeader("Authorization");
        if (authHeader != null && authHeader.startsWith("Bearer ")) {
            String token = authHeader.substring(7);
            // Parse the token (simplified example: always returns the same user)
            return new UserInfo("user-123", "John Doe", new String[]{"USER", "ADMIN"});
        }
        
        // Anonymous user
        return new UserInfo("anonymous", "Guest", new String[]{"GUEST"});
    }
    
    // Log an exception together with the request context
    private void logException(Exception e) {
        RequestContext context = REQUEST_CONTEXT.get();
        UserInfo user = CURRENT_USER.get();
        
        System.err.printf("[%s] Error processing request %s %s for user %s: %s%n",
                context.requestId(), context.method(), context.path(), 
                user.username(), e.getMessage());
        e.printStackTrace();
    }
    
    // Log request completion
    private void logRequestCompletion(int statusCode, Duration duration) {
        RequestContext context = REQUEST_CONTEXT.get();
        UserInfo user = CURRENT_USER.get();
        
        System.out.printf("[%s] Completed %s %s for user %s with status %d in %d ms%n",
                context.requestId(), context.method(), context.path(), 
                user.username(), statusCode, duration.toMillis());
    }
    
    // Access the current request context from any code running inside the filter's scope
    public static RequestContext getCurrentRequestContext() {
        return REQUEST_CONTEXT.get();
    }
    
    // Access the current user from any code running inside the filter's scope
    public static UserInfo getCurrentUser() {
        return CURRENT_USER.get();
    }
    
    // Check whether the current user has a given role
    public static boolean hasRole(String role) {
        UserInfo user = CURRENT_USER.get();
        for (String userRole : user.roles()) {
            if (userRole.equals(role)) {
                return true;
            }
        }
        return false;
    }
    
    // Immutable context record
    record RequestContext(String requestId, Instant startTime, String method, String path, String clientIp) {}
    
    // User info record (note: the String[] roles array itself is still mutable)
    record UserInfo(String id, String username, String[] roles) {}
}

// Usage in a controller or service (the code must run inside the filter's ScopedValue scope)
class UserController {
    
    public void handleRequest() {
        // Get the current request context (the records are nested in ScopedValuesWebFilter)
        ScopedValuesWebFilter.RequestContext context = ScopedValuesWebFilter.getCurrentRequestContext();
        
        // Get the current user
        ScopedValuesWebFilter.UserInfo user = ScopedValuesWebFilter.getCurrentUser();
        
        // Check permissions
        if (!ScopedValuesWebFilter.hasRole("ADMIN")) {
            throw new SecurityException("Admin role required");
        }
        
        // Handle the request using the context and user info
        System.out.printf("Processing request %s for user %s%n", 
                context.requestId(), user.username());
        
        // Asynchronous work: threads started by an ExecutorService do NOT inherit
        // scoped-value bindings (only StructuredTaskScope forks do), so calling
        // getCurrentRequestContext() inside the task would throw NoSuchElementException.
        // Pass the values captured above to the task instead.
        try (var executor = Executors.newVirtualThreadPerTaskExecutor()) {
            executor.submit(() -> {
                System.out.printf("Async task for request %s, user %s%n", 
                        context.requestId(), user.username());
                
                // Perform the asynchronous work...
            });
        }
    }
}
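Whether the controller's asynchronous task can see the request context depends on whether a new thread sees the caller's bindings. The small self-contained sketch below makes the rule visible; the class RebindInExecutor and the value REQUEST_ID are hypothetical, and Java 21 needs --enable-preview. A task submitted to an ordinary ExecutorService sees the scoped value as unbound. Re-binding a value captured in the parent thread restores it. Subtasks forked with StructuredTaskScope inherit bindings automatically, but that case is not shown here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// A sketch; on Java 21 compile and run with --enable-preview.
class RebindInExecutor {

    static final ScopedValue<String> REQUEST_ID = ScopedValue.newInstance();

    // Rebinds the captured value inside the worker thread, then reads it back.
    static String rebindAndRead(String requestId) {
        String[] seen = new String[1];
        ScopedValue.where(REQUEST_ID, requestId).run(() -> seen[0] = REQUEST_ID.get());
        return seen[0];
    }

    static List<String> demo() throws Exception {
        List<Future<String>> futures = new ArrayList<>();
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            ScopedValue.where(REQUEST_ID, "req-42").run(() -> {
                // Plain task: the executor's new thread does not inherit the binding.
                futures.add(executor.submit(() -> REQUEST_ID.orElse("<unbound>")));
                // Capture the value in the parent thread and rebind it inside the task.
                String captured = REQUEST_ID.get();
                futures.add(executor.submit(() -> rebindAndRead(captured)));
            });
            List<String> results = new ArrayList<>();
            for (Future<String> f : futures) {
                results.add(f.get());
            }
            return results;
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(demo()); // [<unbound>, req-42]
    }
}
```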

3. Distributed Tracing and Log Correlation

java

import java.time.Instant;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

public class DistributedTracingWithScopedValues {
    
    // Scoped value holding the current trace context
    private static final ScopedValue<TraceContext> TRACE_CONTEXT = ScopedValue.newInstance();
    
    // In-memory span store for the demo (a real system would export spans to a tracing backend)
    private static final Map<String, TraceSpan> SPANS = new ConcurrentHashMap<>();
    
    public static void main(String[] args) {
        // Simulate handling an HTTP request
        processHttpRequest("/api/users/123", "GET");
    }
    
    private static void processHttpRequest(String path, String method) {
        // Create the trace ID (a real service would usually extract it from an incoming request header)
        String traceId = UUID.randomUUID().toString();
        String spanId = UUID.randomUUID().toString();
        
        // Create the root trace context
        TraceContext rootContext = new TraceContext(traceId, spanId, null, Map.of(
                "http.path", path,
                "http.method", method
        ));
        
        // Bind the trace context with ScopedValue
        ScopedValue.where(TRACE_CONTEXT, rootContext)
            .run(() -> {
                try {
                    // Start the root span
                    startSpan(spanId, "http.request", Map.of(
                            "http.path", path,
                            "http.method", method
                    ));
                    
                    // Handle the request
                    handleRequest(path);
                    
                    // End the root span
                    endSpan(spanId, null);
                } catch (Exception e) {
                    // Record the error and end the span
                    endSpan(spanId, e);
                    throw e;
                }
            });
        
        // Print the collected spans
        System.out.println("\nCollected trace spans:");
        SPANS.values().forEach(System.out::println);
    }
    
    private static void handleRequest(String path) {
        // Create a child span
        withNewSpan("handle-request", Map.of("path", path), () -> {
            log("Processing request: " + path);
            
            // Query the database
            queryDatabase("SELECT * FROM users WHERE id = 123");
            
            // Call an external service
            callExternalService("https://api.example.com/data");
            
            log("Request processing completed");
        });
    }
    
    private static void queryDatabase(String query) {
        // Child span for the database operation
        withNewSpan("db.query", Map.of("db.statement", query), () -> {
            log("Executing database query: " + query);
            
            // Simulate the database operation
            try {
                Thread.sleep(100);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            
            log("Database query completed");
        });
    }
    
    private static void callExternalService(String url) {
        // Child span for the external service call
        withNewSpan("http.client", Map.of("http.url", url), () -> {
            log("Calling external service: " + url);
            
            // Simulate the network call
            try {
                Thread.sleep(200);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            
            // Simulate processing the response
            processServiceResponse();
            
            log("External service call completed");
        });
    }
    
    private static void processServiceResponse() {
        // Child span for response processing
        withNewSpan("process.response", Map.of(), () -> {
            log("Processing service response");
            
            // Simulate processing
            try {
                Thread.sleep(50);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            
            log("Response processing completed");
        });
    }
    
    // Run a code block inside a new child span
    private static void withNewSpan(String name, Map<String, String> attributes, Runnable action) {
        // Get the current (parent) trace context
        TraceContext parentContext = TRACE_CONTEXT.get();
        String newSpanId = UUID.randomUUID().toString();
        
        // Create the child trace context
        TraceContext childContext = new TraceContext(
                parentContext.traceId(),
                newSpanId,
                parentContext.spanId(),
                attributes
        );
        
        // Run the action with the child context bound
        ScopedValue.where(TRACE_CONTEXT, childContext)
            .run(() -> {
                try {
                    // Start the span
                    startSpan(newSpanId, name, attributes);
                    
                    // Run the action
                    action.run();
                    
                    // End the span
                    endSpan(newSpanId, null);
                } catch (Exception e) {
                    // Record the error and end the span
                    endSpan(newSpanId, e);
                    throw e;
                }
            });
    }
    
    // Start a span
    private static void startSpan(String spanId, String name, Map<String, String> attributes) {
        TraceContext context = TRACE_CONTEXT.get();
        
        TraceSpan span = new TraceSpan(
                spanId,
                context.traceId(),
                context.parentSpanId(),
                name,
                Instant.now(),
                null,
                attributes,
                null
        );
        
        SPANS.put(spanId, span);
        log("Started span: " + name);
    }
    
    // End a span
    private static void endSpan(String spanId, Exception error) {
        TraceSpan span = SPANS.get(spanId);
        if (span != null) {
            TraceSpan updatedSpan = new TraceSpan(
                    span.spanId(),
                    span.traceId(),
                    span.parentSpanId(),
                    span.name(),
                    span.startTime(),
                    Instant.now(),
                    span.attributes(),
                    error != null ? error.toString() : null
            );
            
            SPANS.put(spanId, updatedSpan);
            log("Ended span: " + span.name() + (error != null ? " with error: " + error.getMessage() : ""));
        }
    }
    
    // Log a message tagged with the current trace and span IDs
    private static void log(String message) {
        TraceContext context = TRACE_CONTEXT.get();
        System.out.printf("[trace=%s, span=%s] %s%n", 
                context.traceId(), context.spanId(), message);
    }
    
    // Immutable trace context record
    record TraceContext(String traceId, String spanId, String parentSpanId, Map<String, String> attributes) {}
    
    // Immutable span record
    record TraceSpan(
            String spanId,
            String traceId,
            String parentSpanId,
            String name,
            Instant startTime,
            Instant endTime,
            Map<String, String> attributes,
            String error
    ) {}
}
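The withNewSpan pattern reduces to one rule: each new span takes the currently bound span as its parent and rebinds the scoped value for its own duration. The stripped-down sketch below keeps only that rule, so log lines can carry the full span path without any span being passed through method parameters. Span, inSpan, CURRENT and LOG are hypothetical names, and Java 21 needs --enable-preview.

```java
import java.util.ArrayList;
import java.util.List;

// A sketch; on Java 21 compile and run with --enable-preview.
class SpanNesting {

    // Hypothetical minimal span: a name plus a link to its enclosing span (null for the root).
    record Span(String name, Span parent) {
        String path() {
            return parent == null ? name : parent.path() + " > " + name;
        }
    }

    static final ScopedValue<Span> CURRENT = ScopedValue.newInstance();
    static final List<String> LOG = new ArrayList<>();

    // Runs the action in a child of the current span, or in a root span if none is bound.
    static void inSpan(String name, Runnable action) {
        Span parent = CURRENT.isBound() ? CURRENT.get() : null;
        ScopedValue.where(CURRENT, new Span(name, parent)).run(action);
    }

    // Every log line is tagged with the span path taken from the scoped value.
    static void log(String message) {
        LOG.add("[" + CURRENT.get().path() + "] " + message);
    }

    static List<String> demo() {
        LOG.clear();
        inSpan("http.request", () -> {
            log("start");
            inSpan("db.query", () -> log("query"));
            inSpan("http.client", () -> inSpan("process.response", () -> log("parse")));
            log("end");
        });
        return List.copyOf(LOG);
    }

    public static void main(String[] args) {
        demo().forEach(System.out::println);
    }
}
```

The second line printed is "[http.request > db.query] query". After each nested span ends, the enclosing span is current again, so the last line is "[http.request] end".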

4. ScopedValue vs. ThreadLocal

java

import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class ScopedValueVsThreadLocal {
    
    // ThreadLocal variables
    private static final ThreadLocal<String> THREAD_LOCAL_USER = new ThreadLocal<>();
    private static final ThreadLocal<RequestContext> THREAD_LOCAL_CONTEXT = new ThreadLocal<>();
    
    // ScopedValue keys
    private static final ScopedValue<String> SCOPED_VALUE_USER = ScopedValue.newInstance();
    private static final ScopedValue<RequestContext> SCOPED_VALUE_CONTEXT = ScopedValue.newInstance();
    
    // Counters
    private static final AtomicInteger threadLocalLeaks = new AtomicInteger(0);
    private static final AtomicInteger scopedValueAccesses = new AtomicInteger(0);
    
    public static void main(String[] args) throws Exception {
        int taskCount = 1_000_000;
        
        // Test ThreadLocal
        System.out.println("Testing ThreadLocal with virtual threads...");
        Instant threadLocalStart = Instant.now();
        testThreadLocal(taskCount);
        Duration threadLocalDuration = Duration.between(threadLocalStart, Instant.now());
        
        // Request a GC and pause briefly before the second run (System.gc() is only a hint)
        System.gc();
        Thread.sleep(2000);
        
        // Test ScopedValue
        System.out.println("\nTesting ScopedValue with virtual threads...");
        Instant scopedValueStart = Instant.now();
        testScopedValue(taskCount);
        Duration scopedValueDuration = Duration.between(scopedValueStart, Instant.now());
        
        // Print the comparison (a single run without warm-up, so treat the numbers as indicative only)
        System.out.println("\nPerformance Comparison:");
        System.out.printf("ThreadLocal: %d tasks in %d ms (%.2f tasks/second)%n",
                taskCount, threadLocalDuration.toMillis(),
                taskCount / (threadLocalDuration.toMillis() / 1000.0));
        System.out.printf("ScopedValue: %d tasks in %d ms (%.2f tasks/second)%n",
                taskCount, scopedValueDuration.toMillis(),
                taskCount / (scopedValueDuration.toMillis() / 1000.0));
        
        System.out.printf("\nThreadLocal tasks that skipped remove(): %d%n", threadLocalLeaks.get());
        System.out.printf("ScopedValue successful accesses: %d%n", scopedValueAccesses.get());
        
        // Print current heap usage
        Runtime rt = Runtime.getRuntime();
        long usedMemory = (rt.totalMemory() - rt.freeMemory()) / 1024 / 1024;
        System.out.printf("\nFinal memory usage: %d MB%n", usedMemory);
    }
    
    private static void testThreadLocal(int taskCount) throws Exception {
        CountDownLatch latch = new CountDownLatch(taskCount);
        
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            for (int i = 0; i < taskCount; i++) {
                final int taskId = i;
                executor.submit(() -> {
                    try {
                        // Set the ThreadLocal values
                        THREAD_LOCAL_USER.set("user-" + taskId);
                        THREAD_LOCAL_CONTEXT.set(new RequestContext("req-" + taskId, System.currentTimeMillis()));
                        
                        // Run the task
                        processWithThreadLocal();
                        
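                        // Simulate forgetting to clean up the ThreadLocal in every 10th task.
                        // With one virtual thread per task, the thread and its ThreadLocal map
                        // become garbage when the task ends, so nothing actually leaks here;
                        // on a pooled platform thread the stale value would survive into the next task.
                        if (taskId % 10 != 0) {
                            THREAD_LOCAL_USER.remove();
                            THREAD_LOCAL_CONTEXT.remove();
                        } else {
                            // Count the tasks that skipped remove()
                            threadLocalLeaks.incrementAndGet();
                        }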
                    } finally {
                        latch.countDown();
                    }
                });
            }
            
            // Wait for all tasks to finish
            latch.await();
        }
    }
    
    private static void testScopedValue(int taskCount) throws Exception {
        CountDownLatch latch = new CountDownLatch(taskCount);
        
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            for (int i = 0; i < taskCount; i++) {
                final int taskId = i;
                executor.submit(() -> {
                    try {
                        // Bind the scoped values for the duration of run()
                        ScopedValue.where(SCOPED_VALUE_USER, "user-" + taskId)
                            .where(SCOPED_VALUE_CONTEXT, new RequestContext("req-" + taskId, System.currentTimeMillis()))
                            .run(() -> processWithScopedValue());
                    } finally {
                        latch.countDown();
                    }
                });
            }
            
            // Wait for all tasks to finish
            latch.await();
        }
    }
    
    private static void processWithThreadLocal() {
        try {
            // Read the ThreadLocal values
            String user = THREAD_LOCAL_USER.get();
            RequestContext context = THREAD_LOCAL_CONTEXT.get();
            
            // Simulate processing
            Thread.sleep(1);
            
            // Call another method
            subTaskWithThreadLocal();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
    
    private static void subTaskWithThreadLocal() {
        // Read the ThreadLocal values
        String user = THREAD_LOCAL_USER.get();
        RequestContext context = THREAD_LOCAL_CONTEXT.get();
        
        // Simulate processing
        // ...
    }
    
    private static void processWithScopedValue() {
        try {
            // Read the ScopedValue values
            String user = SCOPED_VALUE_USER.get();
            RequestContext context = SCOPED_VALUE_CONTEXT.get();
            
            // Count the successful read
            scopedValueAccesses.incrementAndGet();
            
            // Simulate processing
            Thread.sleep(1);
            
            // Call another method
            subTaskWithScopedValue();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
    
    private static void subTaskWithScopedValue() {
        // Read the ScopedValue values
        String user = SCOPED_VALUE_USER.get();
        RequestContext context = SCOPED_VALUE_CONTEXT.get();
        
        // Simulate processing
        // ...
    }
    
    // Immutable context record
    record RequestContext(String requestId, long timestamp) {}
}
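The benchmark above runs each task in its own virtual thread, so a forgotten remove() is discarded along with the thread. A forgotten remove() really hurts when a pooled thread is reused for the next request. The minimal sketch below shows that case by using a single-thread executor, so that two consecutive tasks run on the same platform thread. Class and variable names are hypothetical, and the ScopedValue part needs --enable-preview on Java 21.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// A sketch; on Java 21 compile and run with --enable-preview (for ScopedValue).
class PooledThreadLocalLeak {

    static final ThreadLocal<String> USER_TL = new ThreadLocal<>();
    static final ScopedValue<String> USER_SV = ScopedValue.newInstance();

    // Returns {value seen by the second ThreadLocal task, value seen by the second ScopedValue task}.
    static String[] demo() throws Exception {
        // A single pooled platform thread, so consecutive tasks reuse the same thread.
        try (ExecutorService pool = Executors.newSingleThreadExecutor()) {
            // Request 1 sets the ThreadLocal and "forgets" to call remove().
            pool.submit(() -> USER_TL.set("alice")).get();
            // Request 2 runs on the same thread and still sees alice's value.
            String leaked = pool.submit(() -> USER_TL.get()).get();

            // Same pattern with a ScopedValue: the binding ends when run() returns.
            pool.submit(() -> ScopedValue.where(USER_SV, "alice").run(() -> USER_SV.get())).get();
            String scoped = pool.submit(() -> USER_SV.isBound() ? USER_SV.get() : "<unbound>").get();

            return new String[] {leaked, scoped};
        }
    }

    public static void main(String[] args) throws Exception {
        String[] r = demo();
        System.out.println("leaked=" + r[0] + ", scoped=" + r[1]); // leaked=alice, scoped=<unbound>
    }
}
```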

5. In Practice: Microservice Context Propagation

java

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class MicroserviceContextPropagation {
    
    // Scoped value holding the request context
    private static final ScopedValue<RequestContext> REQUEST_CONTEXT = ScopedValue.newInstance();
    
    // HTTP client (unused while the calls below are simulated)
    private static final HttpClient httpClient = HttpClient.newBuilder()
            .connectTimeout(Duration.ofSeconds(5))
            .build();
    
    // Simulated service registry
    private static final Map<String, String> SERVICE_REGISTRY = Map.of(
            "user-service", "http://localhost:8081",
            "order-service", "http://localhost:8082",
            "payment-service", "http://localhost:8083"
    );
    
    // Simulated per-service request counts
    private static final Map<String, Integer> REQUEST_COUNTS = new ConcurrentHashMap<>();
    
    public static void main(String[] args) {
        // Simulate an API gateway receiving requests
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            // Handle several concurrent requests
            for (int i = 0; i < 10; i++) {
                executor.submit(() -> handleGatewayRequest("/api/orders/123", "GET"));
            }
        } // close() waits until every submitted request has finished
        
        // Print per-service request counts
        System.out.println("\nRequest counts by service:");
        REQUEST_COUNTS.forEach((service, count) -> 
                System.out.printf("%s: %d requests%n", service, count));
    }
    
    private static void handleGatewayRequest(String path, String method) {
        // Create the request context
        String requestId = UUID.randomUUID().toString();
        String correlationId = UUID.randomUUID().toString();
        String userId = "user-" + (int)(Math.random() * 1000);
        
        RequestContext context = new RequestContext(
                requestId,
                correlationId,
                userId,
                path,
                method,
                System.currentTimeMillis()
        );
        
        // Bind the context with ScopedValue
        ScopedValue.where(REQUEST_CONTEXT, context)
            .run(() -> {
                try {
                    log("API Gateway received request: " + path);
                    
                    // Route the request to the appropriate service
                    if (path.startsWith("/api/orders")) {
                        callMicroservice("order-service", path);
                    } else if (path.startsWith("/api/users")) {
                        callMicroservice("user-service", path);
                    } else {
                        log("Unknown path: " + path);
                    }
                    
                    log("Request completed");
                } catch (Exception e) {
                    log("Error processing request: " + e.getMessage());
                }
            });
    }
    
    private static void callMicroservice(String serviceName, String path) {
        // Get the current request context
        RequestContext context = REQUEST_CONTEXT.get();
        
        // Increment the request count for this service
        REQUEST_COUNTS.compute(serviceName, (k, v) -> (v == null) ? 1 : v + 1);
        
        log("Calling " + serviceName + " with path " + path);
        
        try {
            // Build the request and propagate the context as headers
            String serviceUrl = SERVICE_REGISTRY.get(serviceName) + path;
            HttpRequest request = HttpRequest.newBuilder()
                    .uri(new URI(serviceUrl))
                    .header("X-Request-ID", context.requestId())
                    .header("X-Correlation-ID", context.correlationId())
                    .header("X-User-ID", context.userId())
                    .GET()
                    .build();
            
            // Simulate sending the request;
            // a real deployment would actually send it, e.g.
            // httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())
            
            // Simulate the service response
            simulateServiceResponse(serviceName, path);
            
        } catch (Exception e) {
            log("Error calling " + serviceName + ": " + e.getMessage());
        }
    }
    
    private static void simulateServiceResponse(String serviceName, String path) {
        // Get the current request context
        RequestContext context = REQUEST_CONTEXT.get();
        
        // Simulate the service's processing
        try {
            Thread.sleep(100);
            
            log(serviceName + " processing request " + path);
            
            // Simulate service-to-service calls
            if (serviceName.equals("order-service") && path.contains("/orders/")) {
                // The order service must call the user service and the payment service
                
                // Call them in parallel on virtual threads. Threads started by an
                // ExecutorService do NOT inherit scoped-value bindings (only
                // StructuredTaskScope forks do), so each task rebinds REQUEST_CONTEXT
                // to the captured context before calling callMicroservice().
                try (var executor = Executors.newVirtualThreadPerTaskExecutor()) {
                    CompletableFuture<Void> userFuture = CompletableFuture.runAsync(
                            () -> ScopedValue.where(REQUEST_CONTEXT, context).run(
                                    () -> callMicroservice("user-service", "/api/users/" + context.userId())),
                            executor);
                    
                    CompletableFuture<Void> paymentFuture = CompletableFuture.runAsync(
                            () -> ScopedValue.where(REQUEST_CONTEXT, context).run(
                                    () -> callMicroservice("payment-service", "/api/payments/for-order/123")),
                            executor);
                    
                    // Wait for both calls to complete
                    CompletableFuture.allOf(userFuture, paymentFuture).join();
                }
                
                log(serviceName + " completed processing with dependent services");
            } else {
                // Other services just do some simple processing
                Thread.sleep(50);
                log(serviceName + " completed processing");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log(serviceName + " processing interrupted");
        }
    }
    
    // Log a message that automatically includes the context
    private static void log(String message) {
        RequestContext context = REQUEST_CONTEXT.get();
        System.out.printf("[req=%s, corr=%s, user=%s] %s%n",
                context.requestId(), context.correlationId(), context.userId(), message);
    }
    
    // Immutable request context record
    record RequestContext(
            String requestId,
            String correlationId,
            String userId,
            String path,
            String method,
            long timestamp
    ) {}
}
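Scoped values do not cross process boundaries. That is why the code above copies the context into X-Request-ID and X-Correlation-ID headers, and the receiving service has to rebuild its own context from those headers and bind it again. The minimal round-trip sketch below assumes a reduced two-field RequestContext; ContextHeaders, outgoing and handleIncoming are hypothetical names, and Java 21 needs --enable-preview. Nothing is sent over the network: the built request's headers are passed directly to the receiving side.

```java
import java.net.URI;
import java.net.http.HttpHeaders;
import java.net.http.HttpRequest;

// A sketch; on Java 21 compile and run with --enable-preview (for ScopedValue).
class ContextHeaders {

    // Hypothetical context with just the IDs that cross service boundaries.
    record RequestContext(String requestId, String correlationId) {}

    static final ScopedValue<RequestContext> CONTEXT = ScopedValue.newInstance();

    // Caller side: copy the bound context into the outgoing request's headers.
    static HttpRequest outgoing(String url) {
        RequestContext ctx = CONTEXT.get();
        return HttpRequest.newBuilder(URI.create(url))
                .header("X-Request-ID", ctx.requestId())
                .header("X-Correlation-ID", ctx.correlationId())
                .GET()
                .build();
    }

    // Callee side: rebuild the context from the incoming headers and bind it for the handler.
    static void handleIncoming(HttpHeaders headers, Runnable handler) {
        RequestContext ctx = new RequestContext(
                headers.firstValue("X-Request-ID").orElse("unknown"),
                headers.firstValue("X-Correlation-ID").orElse("unknown"));
        ScopedValue.where(CONTEXT, ctx).run(handler);
    }

    static String demo() {
        HttpRequest[] sent = new HttpRequest[1];
        ScopedValue.where(CONTEXT, new RequestContext("req-1", "corr-9"))
                .run(() -> sent[0] = outgoing("http://localhost:8081/api/users/123"));
        // The original binding has ended; the receiving side only sees what the headers carry.
        String[] seen = new String[1];
        handleIncoming(sent[0].headers(), () -> seen[0] = CONTEXT.get().correlationId());
        return seen[0];
    }

    public static void main(String[] args) {
        System.out.println(demo()); // corr-9
    }
}
```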

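Summary and Best Practices

Structured concurrency best practices

  1. Make scope boundaries explicit: open each scope in try-with-resources so it is always closed
  2. Choose the right policy: ShutdownOnFailure, ShutdownOnSuccess, or a custom StructuredTaskScope subclass, depending on what you need
  3. Handle exceptions: after join(), call throwIfFailed() to propagate subtask failures
  4. Set timeouts: use joinUntil(Instant) to put a deadline on the whole scope
  5. Avoid deep nesting: keep the scope hierarchy simple and clear
  6. Combine with virtual threads: by default StructuredTaskScope runs each forked subtask in a new virtual thread, so the two work well together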

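Virtual thread best practices

  1. Where to use them: prefer virtual threads for I/O-bound tasks; they do not make CPU-bound work faster
  2. Avoid blocking inside synchronized: in JDK 21, a blocking operation inside a synchronized block or method pins the virtual thread to its carrier (platform) thread; use java.util.concurrent locks such as ReentrantLock instead
  3. Replace thread pools: for thread-per-task workloads, replace fixed-size pools with Executors.newVirtualThreadPerTaskExecutor(); virtual threads are cheap and should not be pooled
  4. Use ThreadLocal sparingly: every virtual thread gets its own copy, and there may be millions of them; consider ScopedValue instead
  5. Monitoring and debugging: use JDK Flight Recorder (for example the jdk.VirtualThreadPinned event) and jcmd thread dumps
  6. Carrier pool size: the virtual-thread scheduler's parallelism defaults to the number of available processors (adjustable with -Djdk.virtualThreadScheduler.parallelism) and rarely needs tuning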
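The sketch below puts item 2 of the virtual thread list into practice: a counter whose critical section blocks briefly, guarded by a ReentrantLock instead of synchronized. The class name and the Thread.sleep stand-in for blocking I/O are illustrative assumptions. Virtual threads are a final feature in Java 21, so no preview flag is needed. In real code, keep critical sections short; the sleep only simulates a blocking call.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.locks.ReentrantLock;

// A sketch for Java 21 (virtual threads need no preview flag).
class LockInsteadOfSynchronized {

    private final ReentrantLock lock = new ReentrantLock();
    private long counter;

    // Blocking while holding a ReentrantLock lets the virtual thread unmount from its carrier;
    // the same sleep inside a synchronized block would pin the carrier on JDK 21.
    void incrementSlowly() throws InterruptedException {
        lock.lock();
        try {
            Thread.sleep(1); // stand-in for a short blocking call made under the lock
            counter++;
        } finally {
            lock.unlock();
        }
    }

    // Reads the counter under the same lock, so the final value is safely visible.
    long value() {
        lock.lock();
        try {
            return counter;
        } finally {
            lock.unlock();
        }
    }

    static long countWith(int tasks) {
        LockInsteadOfSynchronized c = new LockInsteadOfSynchronized();
        try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
            for (int i = 0; i < tasks; i++) {
                executor.submit(() -> {
                    c.incrementSlowly();
                    return null; // a Callable, so the checked InterruptedException is allowed
                });
            }
        } // close() waits for every task to finish
        return c.value();
    }

    public static void main(String[] args) {
        System.out.println(countWith(100)); // 100
    }
}
```

countWith(100) should return 100. Because the lock serializes the sleeps, the run takes at least about 100 ms.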

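ScopedValue best practices

  1. Immutable data: objects stored in a ScopedValue should be immutable
  2. Don't overuse: reserve them for context that must travel down the call chain (request IDs, the current user, trace context)
  3. Explicit scopes: bind with where(...).run(...) or .call(...); the binding is visible to the entire call chain and to StructuredTaskScope forks, but not to threads started by an ordinary ExecutorService
  4. Nesting: where(...) can be nested to rebind a value for an inner scope; the outer value is restored when the inner scope ends
  5. Replace ThreadLocal: in virtual-thread code, prefer ScopedValue over ThreadLocal for read-only context
  6. Exception handling: bindings are removed automatically when run() or call() completes, normally or with an exception, so no cleanup code is needed; log or enrich exceptions inside the scope if you need the bound values, because they are gone once the exception leaves it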
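The sketch below shows item 6 of the ScopedValue list in code. The operation fails, the catch block inside the scope attaches the bound request ID to the error, and the caller then observes that the binding is already gone. Names are hypothetical, and Java 21 needs --enable-preview.

```java
// A sketch; on Java 21 compile and run with --enable-preview.
class ScopedValueOnException {

    static final ScopedValue<String> REQUEST_ID = ScopedValue.newInstance();

    // Hypothetical work that fails.
    static void failingWork() {
        throw new IllegalStateException("boom");
    }

    static String handle() {
        try {
            ScopedValue.where(REQUEST_ID, "req-7").run(() -> {
                try {
                    failingWork();
                } catch (IllegalStateException e) {
                    // Still inside the scope, so the request ID can be attached to the error.
                    throw new IllegalStateException("[" + REQUEST_ID.get() + "] " + e.getMessage(), e);
                }
            });
            return "ok";
        } catch (IllegalStateException e) {
            // The exception has left the scope: the binding was removed automatically.
            return e.getMessage() + ", bound afterwards=" + REQUEST_ID.isBound();
        }
    }

    public static void main(String[] args) {
        System.out.println(handle()); // [req-7] boom, bound afterwards=false
    }
}
```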

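Used together, these three features can significantly improve the scalability and maintainability of concurrent Java code, and they suit I/O-heavy cloud-native applications and microservice architectures particularly well.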

