Please provide the code changes or file diffs you would like me to summarize.

This commit is contained in:
liumangmang
2026-05-07 14:12:54 +08:00
parent 3f0ebe24e0
commit 2d9011b606
17 changed files with 375 additions and 379 deletions
@@ -17,8 +17,6 @@ import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource; import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
@Configuration @Configuration
@EnableWebSecurity @EnableWebSecurity
@@ -2,6 +2,7 @@ package com.sshmanager.config;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import org.springframework.lang.NonNull;
import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import org.springframework.web.servlet.resource.PathResourceResolver; import org.springframework.web.servlet.resource.PathResourceResolver;
@@ -15,13 +16,13 @@ import java.io.IOException;
public class SpaForwardConfig implements WebMvcConfigurer { public class SpaForwardConfig implements WebMvcConfigurer {
@Override @Override
public void addResourceHandlers(ResourceHandlerRegistry registry) { public void addResourceHandlers(@NonNull ResourceHandlerRegistry registry) {
registry.addResourceHandler("/**") registry.addResourceHandler("/**")
.addResourceLocations("classpath:/static/") .addResourceLocations("classpath:/static/")
.resourceChain(true) .resourceChain(true)
.addResolver(new PathResourceResolver() { .addResolver(new PathResourceResolver() {
@Override @Override
protected Resource getResource(String path, Resource location) throws IOException { protected Resource getResource(@NonNull String path, @NonNull Resource location) throws IOException {
Resource resource = location.createRelative(path); Resource resource = location.createRelative(path);
if (resource.exists() && resource.isReadable()) { if (resource.exists() && resource.isReadable()) {
return resource; return resource;
@@ -9,6 +9,8 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.HandshakeInterceptor;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import java.util.Map; import java.util.Map;
@Component @Component
@@ -21,8 +23,8 @@ public class TerminalHandshakeInterceptor implements HandshakeInterceptor {
} }
@Override @Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, public boolean beforeHandshake(@NonNull ServerHttpRequest request, @NonNull ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception { @NonNull WebSocketHandler wsHandler, @NonNull Map<String, Object> attributes) throws Exception {
if (!(request instanceof ServletServerHttpRequest)) { if (!(request instanceof ServletServerHttpRequest)) {
return false; return false;
} }
@@ -50,7 +52,7 @@ public class TerminalHandshakeInterceptor implements HandshakeInterceptor {
} }
@Override @Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, public void afterHandshake(@NonNull ServerHttpRequest request, @NonNull ServerHttpResponse response,
WebSocketHandler wsHandler, Exception exception) { @NonNull WebSocketHandler wsHandler, @Nullable Exception exception) {
} }
} }
@@ -5,6 +5,7 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket; import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer; import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.lang.NonNull;
@Configuration @Configuration
@EnableWebSocket @EnableWebSocket
@@ -20,7 +21,8 @@ public class WebSocketConfig implements WebSocketConfigurer {
} }
@Override @Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { @SuppressWarnings("null")
public void registerWebSocketHandlers(@NonNull WebSocketHandlerRegistry registry) {
registry.addHandler(terminalWebSocketHandler, "/ws/terminal") registry.addHandler(terminalWebSocketHandler, "/ws/terminal")
.addInterceptors(terminalHandshakeInterceptor) .addInterceptors(terminalHandshakeInterceptor)
// Docker/remote deployments often use non-localhost origins. // Docker/remote deployments often use non-localhost origins.
@@ -14,7 +14,6 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody; import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
@@ -33,7 +32,6 @@ import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
@RestController @RestController
@RequestMapping("/api/sftp") @RequestMapping("/api/sftp")
@@ -282,7 +280,7 @@ public class SftpController {
return ResponseEntity.ok() return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + filename + "\"") .header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + filename + "\"")
.contentType(MediaType.APPLICATION_OCTET_STREAM) .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_OCTET_STREAM_VALUE)
.body(stream); .body(stream);
} catch (Exception e) { } catch (Exception e) {
return ResponseEntity.status(500).build(); return ResponseEntity.status(500).build();
@@ -331,12 +329,11 @@ public class SftpController {
file.transferTo(tempFile); file.transferTo(tempFile);
final java.io.File savedFile = tempFile; final java.io.File savedFile = tempFile;
UploadTaskStatus status = new UploadTaskStatus(taskId, userId, connectionId, UploadTaskStatus status = new UploadTaskStatus(taskId, userId, filename, file.getSize());
path, filename, file.getSize());
status.setController(this); status.setController(this);
uploadTasks.put(taskKey, status); uploadTasks.put(taskKey, status);
Future<?> future = transferTaskExecutor.submit(() -> { transferTaskExecutor.submit(() -> {
status.setStatus("running"); status.setStatus("running");
try { try {
withSessionLock(key, () -> { withSessionLock(key, () -> {
@@ -387,8 +384,6 @@ public class SftpController {
} }
} }
}); });
status.setFuture(future);
Map<String, Object> result = new HashMap<>(); Map<String, Object> result = new HashMap<>();
result.put("taskId", taskId); result.put("taskId", taskId);
result.put("message", "Upload started"); result.put("message", "Upload started");
@@ -654,6 +649,7 @@ public class SftpController {
} }
@GetMapping("/transfer-remote/tasks/{taskId}/progress") @GetMapping("/transfer-remote/tasks/{taskId}/progress")
@SuppressWarnings("null")
public SseEmitter streamTransferProgress( public SseEmitter streamTransferProgress(
@PathVariable String taskId, @PathVariable String taskId,
Authentication authentication) { Authentication authentication) {
@@ -692,6 +688,7 @@ public class SftpController {
} }
@GetMapping("/upload/tasks/{taskId}/progress") @GetMapping("/upload/tasks/{taskId}/progress")
@SuppressWarnings("null")
public SseEmitter streamUploadProgress( public SseEmitter streamUploadProgress(
@PathVariable String taskId, @PathVariable String taskId,
Authentication authentication) { Authentication authentication) {
@@ -739,6 +736,7 @@ public class SftpController {
} }
} }
@SuppressWarnings("null")
private void broadcastProgress(String taskKey, Map<String, Object> data) { private void broadcastProgress(String taskKey, Map<String, Object> data) {
CopyOnWriteArrayList<SseEmitter> emitters = taskEmitters.get(taskKey); CopyOnWriteArrayList<SseEmitter> emitters = taskEmitters.get(taskKey);
if (emitters == null || emitters.isEmpty()) { if (emitters == null || emitters.isEmpty()) {
@@ -976,10 +974,7 @@ public class SftpController {
public static class UploadTaskStatus { public static class UploadTaskStatus {
private final String taskId; private final String taskId;
private final Long userId; private final Long userId;
private final Long connectionId;
private final String path;
private final String filename; private final String filename;
private final long fileSize;
private final long createdAt; private final long createdAt;
private volatile String status; private volatile String status;
private volatile String error; private volatile String error;
@@ -987,17 +982,12 @@ public class SftpController {
private volatile long finishedAt; private volatile long finishedAt;
private final AtomicLong totalBytes; private final AtomicLong totalBytes;
private final AtomicLong transferredBytes; private final AtomicLong transferredBytes;
private volatile Future<?> future;
private volatile SftpController controller; private volatile SftpController controller;
public UploadTaskStatus(String taskId, Long userId, Long connectionId, public UploadTaskStatus(String taskId, Long userId, String filename, long fileSize) {
String path, String filename, long fileSize) {
this.taskId = taskId; this.taskId = taskId;
this.userId = userId; this.userId = userId;
this.connectionId = connectionId;
this.path = path;
this.filename = filename; this.filename = filename;
this.fileSize = fileSize;
this.createdAt = System.currentTimeMillis(); this.createdAt = System.currentTimeMillis();
this.status = "queued"; this.status = "queued";
this.totalBytes = new AtomicLong(fileSize); this.totalBytes = new AtomicLong(fileSize);
@@ -1012,9 +1002,7 @@ public class SftpController {
this.controller = controller; this.controller = controller;
} }
public void setFuture(Future<?> future) {
this.future = future;
}
public void setStatus(String status) { public void setStatus(String status) {
this.status = status; this.status = status;
@@ -12,6 +12,7 @@ import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler; import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.lang.NonNull;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@@ -46,7 +47,8 @@ public class TerminalWebSocketHandler extends TextWebSocketHandler {
} }
@Override @Override
public void afterConnectionEstablished(WebSocketSession webSocketSession) throws Exception { @SuppressWarnings("null")
public void afterConnectionEstablished(@NonNull WebSocketSession webSocketSession) throws Exception {
Long connectionId = (Long) webSocketSession.getAttributes().get("connectionId"); Long connectionId = (Long) webSocketSession.getAttributes().get("connectionId");
String username = (String) webSocketSession.getAttributes().get("username"); String username = (String) webSocketSession.getAttributes().get("username");
@@ -101,7 +103,7 @@ public class TerminalWebSocketHandler extends TextWebSocketHandler {
} }
@Override @Override
protected void handleTextMessage(WebSocketSession webSocketSession, TextMessage message) throws Exception { protected void handleTextMessage(@NonNull WebSocketSession webSocketSession, @NonNull TextMessage message) throws Exception {
SshService.SshSession sshSession = sessions.get(webSocketSession.getId()); SshService.SshSession sshSession = sessions.get(webSocketSession.getId());
if (sshSession != null && sshSession.isConnected()) { if (sshSession != null && sshSession.isConnected()) {
lastActivity.put(webSocketSession.getId(), System.currentTimeMillis()); lastActivity.put(webSocketSession.getId(), System.currentTimeMillis());
@@ -122,7 +124,7 @@ public class TerminalWebSocketHandler extends TextWebSocketHandler {
} }
@Override @Override
public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus status) throws Exception { public void afterConnectionClosed(@NonNull WebSocketSession webSocketSession, @NonNull CloseStatus status) throws Exception {
SshService.SshSession sshSession = sessions.remove(webSocketSession.getId()); SshService.SshSession sshSession = sessions.remove(webSocketSession.getId());
lastActivity.remove(webSocketSession.getId()); lastActivity.remove(webSocketSession.getId());
if (sshSession != null) { if (sshSession != null) {
@@ -6,7 +6,6 @@ import lombok.AllArgsConstructor;
import javax.persistence.*; import javax.persistence.*;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@@ -8,6 +8,7 @@ import org.springframework.security.web.authentication.WebAuthenticationDetailsS
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.lang.NonNull;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
@@ -27,8 +28,8 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
} }
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, protected void doFilterInternal(@NonNull HttpServletRequest request, @NonNull HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException { @NonNull FilterChain filterChain) throws ServletException, IOException {
try { try {
String jwt = getJwtFromRequest(request); String jwt = getJwtFromRequest(request);
if (StringUtils.hasText(jwt) && tokenProvider.validateToken(jwt)) { if (StringUtils.hasText(jwt) && tokenProvider.validateToken(jwt)) {
@@ -7,6 +7,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.lang.NonNull;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
@@ -29,7 +30,7 @@ public class PasswordExpirationFilter extends OncePerRequestFilter {
} }
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) protected void doFilterInternal(@NonNull HttpServletRequest request, @NonNull HttpServletResponse response, @NonNull FilterChain filterChain)
throws ServletException, IOException { throws ServletException, IOException {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication != null && authentication.isAuthenticated()) { if (authentication != null && authentication.isAuthenticated()) {
@@ -59,6 +59,7 @@ public class BackupService {
} }
@Transactional @Transactional
@SuppressWarnings("null")
public BackupImportResponseDto importBackup(Long userId, BackupPackageDto backupPackage) { public BackupImportResponseDto importBackup(Long userId, BackupPackageDto backupPackage) {
if (backupPackage == null) { if (backupPackage == null) {
throw new IllegalArgumentException("Backup package is required"); throw new IllegalArgumentException("Backup package is required");
@@ -15,6 +15,7 @@ import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Service @Service
@SuppressWarnings("null")
public class ConnectionService { public class ConnectionService {
private final ConnectionRepository connectionRepository; private final ConnectionRepository connectionRepository;
@@ -28,17 +28,13 @@ import java.util.Optional;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
@SuppressWarnings("null")
class ConnectionControllerTest { class ConnectionControllerTest {
@Mock @Mock
@@ -45,6 +45,7 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
@SuppressWarnings("null")
class SftpControllerTest { class SftpControllerTest {
@Mock @Mock
@@ -26,6 +26,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
@SuppressWarnings("null")
class BackupServiceTest { class BackupServiceTest {
@Mock @Mock
@@ -27,6 +27,7 @@ import static org.mockito.Mockito.when;
import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.lenient;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
@SuppressWarnings("null")
class ConnectionServiceTest { class ConnectionServiceTest {
@Mock @Mock
@@ -23,6 +23,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
@SuppressWarnings("null")
class SessionTreeLayoutServiceTest { class SessionTreeLayoutServiceTest {
@Mock @Mock
@@ -110,12 +110,12 @@ export default function TransferCenterModal({
const unsubscribe = subscribeUploadProgress(taskId, (task) => { const unsubscribe = subscribeUploadProgress(taskId, (task) => {
updateTaskGroup(groupId, (current) => { updateTaskGroup(groupId, (current) => {
const nextItems = current.items.map((item) => const nextItems = current.items.map((item) =>
item.taskId === task.taskId item.taskId === task.taskId
? { ? {
...item, ...item,
progress: task.progress, progress: task.progress,
status: task.status, status: task.status,
message: task.error || (task.status === 'success' ? '上传完成' : task.status === 'error' ? '上传失败' : task.status === 'cancelled' ? '已取消' : '正在传输...'), message: task.error || (task.status === 'success' ? '上传完成' : task.status === 'error' ? '上传失败' : '正在传输...'),
} }
: item, : item,
) )
@@ -231,9 +231,9 @@ export default function TransferCenterModal({
<div className="rounded-xl border border-slate-700 bg-black px-4 py-3 text-sm text-slate-400"></div> <div className="rounded-xl border border-slate-700 bg-black px-4 py-3 text-sm text-slate-400"></div>
</div> </div>
</div> </div>
<div className="flex-1 space-y-2"> <div className="flex min-h-0 flex-1 flex-col space-y-2">
<span className="text-sm text-slate-300">3. </span> <span className="shrink-0 text-sm text-slate-300">3. </span>
<div className="rounded-2xl border border-slate-700 bg-black p-2"> <div className="min-h-0 flex-1 overflow-y-auto rounded-2xl border border-slate-700 bg-black p-2">
{connections.map((server) => { {connections.map((server) => {
const st = connectionStatuses[server.id] const st = connectionStatuses[server.id]
const isOnline = st === 'online' const isOnline = st === 'online'