feat: refine sftp pane upload workflow

This commit is contained in:
liumangmang
2026-04-22 17:59:07 +08:00
parent 423cca97a6
commit 165cc0e35b
10 changed files with 1188 additions and 74 deletions
@@ -40,6 +40,7 @@ import java.util.stream.Stream;
public class SftpController {
private static final Logger log = LoggerFactory.getLogger(SftpController.class);
private static final String UPLOAD_CONFLICT_CODE = "SFTP_UPLOAD_CONFLICT";
private final ConnectionService connectionService;
private final UserRepository userRepository;
@@ -292,11 +293,32 @@ public class SftpController {
public ResponseEntity<Map<String, Object>> upload(
@RequestParam Long connectionId,
@RequestParam String path,
@RequestParam(defaultValue = "false") boolean overwrite,
@RequestParam("file") MultipartFile file,
Authentication authentication) {
java.io.File tempFile = null;
try {
Long userId = getCurrentUserId(authentication);
String key = sessionKey(userId, connectionId);
String filename = file.getOriginalFilename() != null ? file.getOriginalFilename() : "upload.bin";
String remotePath = resolveUploadPath(path, filename);
UploadConflictInfo initialConflict = withSessionLock(key, () -> {
try {
SftpService.SftpSession session = getOrCreateSession(connectionId, userId);
return detectUploadConflict(session, remotePath, filename, overwrite);
} catch (Exception e) {
SftpService.SftpSession existing = sessions.remove(key);
if (existing != null) {
existing.disconnect();
}
throw new RuntimeException(e);
}
});
if (initialConflict != null) {
return buildUploadConflictResponse(initialConflict);
}
String taskId = UUID.randomUUID().toString();
String taskKey = uploadTaskKey(userId, taskId);
@@ -305,25 +327,25 @@ public class SftpController {
if (!uploadTempDir.exists() && !uploadTempDir.mkdirs()) {
throw new IOException("Failed to create upload temp directory: " + uploadTempDir.getAbsolutePath());
}
tempFile = new java.io.File(uploadTempDir, taskId + "_" + file.getOriginalFilename());
tempFile = new java.io.File(uploadTempDir, taskId + "_" + filename);
file.transferTo(tempFile);
final java.io.File savedFile = tempFile;
UploadTaskStatus status = new UploadTaskStatus(taskId, userId, connectionId,
path, file.getOriginalFilename(), file.getSize());
path, filename, file.getSize());
status.setController(this);
uploadTasks.put(taskKey, status);
Future<?> future = transferTaskExecutor.submit(() -> {
status.setStatus("running");
String key = sessionKey(userId, connectionId);
try {
withSessionLock(key, () -> {
try {
SftpService.SftpSession session = getOrCreateSession(connectionId, userId);
String remotePath = (path == null || path.isEmpty() || path.equals("/"))
? "/" + savedFile.getName().substring(savedFile.getName().indexOf("_") + 1)
: (path.endsWith("/") ? path + savedFile.getName().substring(savedFile.getName().indexOf("_") + 1) : path + "/" + savedFile.getName().substring(savedFile.getName().indexOf("_") + 1));
UploadConflictInfo conflict = detectUploadConflict(session, remotePath, filename, overwrite);
if (conflict != null) {
throw new IllegalStateException(conflict.message);
}
AtomicLong transferred = new AtomicLong(0);
try (java.io.InputStream in = new java.io.FileInputStream(savedFile)) {
@@ -383,6 +405,42 @@ public class SftpController {
}
}
private String resolveUploadPath(String path, String filename) {
if (path == null || path.isEmpty() || "/".equals(path)) {
return "/" + filename;
}
return path.endsWith("/") ? path + filename : path + "/" + filename;
}
private UploadConflictInfo detectUploadConflict(SftpService.SftpSession session,
String remotePath,
String filename,
boolean overwrite) throws Exception {
SftpService.PathInfo existing = sftpService.statIfExists(session, remotePath);
if (existing == null) {
return null;
}
if (overwrite && !existing.directory) {
return null;
}
boolean canOverwrite = !existing.directory;
String message = existing.directory
? "目标目录中已存在同名文件夹,无法覆盖。"
: "目标目录中已存在同名文件。";
return new UploadConflictInfo(filename, existing.directory ? "dir" : "file", canOverwrite, message);
}
private ResponseEntity<Map<String, Object>> buildUploadConflictResponse(UploadConflictInfo conflict) {
Map<String, Object> response = new HashMap<>();
response.put("code", UPLOAD_CONFLICT_CODE);
response.put("fileName", conflict.fileName);
response.put("conflictType", conflict.conflictType);
response.put("canOverwrite", conflict.canOverwrite);
response.put("message", conflict.message);
return ResponseEntity.status(409).body(response);
}
@DeleteMapping("/delete")
public ResponseEntity<Map<String, String>> delete(
@RequestParam Long connectionId,
@@ -739,6 +797,20 @@ public class SftpController {
private final SftpSessionExpiryCleanup cleanupTask = new SftpSessionExpiryCleanup();
private static class UploadConflictInfo {
private final String fileName;
private final String conflictType;
private final boolean canOverwrite;
private final String message;
private UploadConflictInfo(String fileName, String conflictType, boolean canOverwrite, String message) {
this.fileName = fileName;
this.conflictType = conflictType;
this.canOverwrite = canOverwrite;
this.message = message;
}
}
public static class SftpSessionExpiryCleanup {
private final Map<String, Long> lastAccessTime = new ConcurrentHashMap<>();
@@ -91,10 +91,10 @@ public class SftpService {
}
public static class FileInfo {
public String name;
public boolean directory;
public long size;
public long mtime;
public String name;
public boolean directory;
public long size;
public long mtime;
public FileInfo(String name, boolean directory, long size, long mtime) {
this.name = name;
@@ -104,16 +104,24 @@ public class SftpService {
}
}
public static class PathInfo {
public final boolean directory;
public PathInfo(boolean directory) {
this.directory = directory;
}
}
public interface TransferProgressListener {
void onStart(long totalBytes);
void onProgress(long transferredBytes, long totalBytes);
}
public List<FileInfo> listFiles(SftpSession sftpSession, String path) throws Exception {
String listPath = (path == null || path.trim().isEmpty()) ? "." : path.trim();
try {
Vector<?> entries = sftpSession.getChannel().ls(listPath);
public List<FileInfo> listFiles(SftpSession sftpSession, String path) throws Exception {
String listPath = (path == null || path.trim().isEmpty()) ? "." : path.trim();
try {
Vector<?> entries = sftpSession.getChannel().ls(listPath);
List<FileInfo> result = new ArrayList<>();
for (Object obj : entries) {
ChannelSftp.LsEntry entry = (ChannelSftp.LsEntry) obj;
@@ -297,4 +305,16 @@ public class SftpService {
}
}
}
public PathInfo statIfExists(SftpSession sftpSession, String path) throws Exception {
try {
SftpATTRS attrs = sftpSession.getChannel().stat(path);
return new PathInfo(attrs.isDir());
} catch (SftpException e) {
if (e.id == ChannelSftp.SSH_FX_NO_SUCH_FILE) {
return null;
}
throw e;
}
}
}
@@ -10,15 +10,24 @@ import com.sshmanager.service.SftpService;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.Authentication;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.mock.web.MockMultipartFile;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.io.TempDir;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -28,7 +37,9 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
@@ -45,15 +56,26 @@ class SftpControllerTest {
@Mock
private SftpService sftpService;
@InjectMocks
private SftpController sftpController;
@TempDir
Path tempDir;
@BeforeEach
void setUp() {
User user = new User();
user.setId(1L);
user.setUsername("alice");
when(userRepository.findByUsername("alice")).thenReturn(Optional.of(user));
sftpController = new SftpController(connectionService, userRepository, sftpService, tempDir.toString());
}
@AfterEach
void tearDown() {
ExecutorService executor = (ExecutorService) ReflectionTestUtils.getField(sftpController, "transferTaskExecutor");
if (executor != null) {
executor.shutdownNow();
}
}
@Test
@@ -133,17 +155,93 @@ class SftpControllerTest {
assertTrue(response.getBody().get("error").contains("boom"));
}
@Test
void uploadReturnsConflictWhenTargetFileExistsAndOverwriteDisabled() throws Exception {
when(connectionService.getConnectionForSsh(anyLong(), eq(1L))).thenReturn(new Connection());
SftpService.SftpSession session = connectedSession(true);
when(sftpService.connect(any(Connection.class), any(), any(), any())).thenReturn(session);
when(sftpService.statIfExists(session, "/uploads/demo.txt")).thenReturn(new SftpService.PathInfo(false));
ResponseEntity<Map<String, Object>> response = sftpController.upload(
7L,
"/uploads",
false,
uploadFile("demo.txt"),
authentication()
);
assertEquals(HttpStatus.CONFLICT, response.getStatusCode());
assertEquals("SFTP_UPLOAD_CONFLICT", response.getBody().get("code"));
assertEquals("demo.txt", response.getBody().get("fileName"));
assertEquals("file", response.getBody().get("conflictType"));
assertEquals(Boolean.TRUE, response.getBody().get("canOverwrite"));
verify(sftpService, never()).upload(any(), anyString(), any(InputStream.class), any(SftpService.TransferProgressListener.class));
try (Stream<Path> files = Files.list(tempDir)) {
assertEquals(0L, files.count());
}
}
@Test
void uploadAllowsOverwriteWhenTargetFileExistsAndOverwriteEnabled() throws Exception {
when(connectionService.getConnectionForSsh(anyLong(), eq(1L))).thenReturn(new Connection());
SftpService.SftpSession session = connectedSession(true);
when(sftpService.connect(any(Connection.class), any(), any(), any())).thenReturn(session);
when(sftpService.statIfExists(session, "/uploads/demo.txt")).thenReturn(new SftpService.PathInfo(false));
ResponseEntity<Map<String, Object>> response = sftpController.upload(
7L,
"/uploads",
true,
uploadFile("demo.txt"),
authentication()
);
assertEquals(HttpStatus.OK, response.getStatusCode());
assertTrue(response.getBody().containsKey("taskId"));
verify(sftpService, timeout(1000)).upload(
eq(session),
eq("/uploads/demo.txt"),
any(InputStream.class),
any(SftpService.TransferProgressListener.class)
);
}
@Test
void uploadReturnsConflictWhenTargetDirectoryExistsEvenWithOverwriteEnabled() throws Exception {
when(connectionService.getConnectionForSsh(anyLong(), eq(1L))).thenReturn(new Connection());
SftpService.SftpSession session = connectedSession(true);
when(sftpService.connect(any(Connection.class), any(), any(), any())).thenReturn(session);
when(sftpService.statIfExists(session, "/uploads/demo.txt")).thenReturn(new SftpService.PathInfo(true));
ResponseEntity<Map<String, Object>> response = sftpController.upload(
7L,
"/uploads",
true,
uploadFile("demo.txt"),
authentication()
);
assertEquals(HttpStatus.CONFLICT, response.getStatusCode());
assertEquals("dir", response.getBody().get("conflictType"));
assertEquals(Boolean.FALSE, response.getBody().get("canOverwrite"));
verify(sftpService, never()).upload(any(), anyString(), any(InputStream.class), any(SftpService.TransferProgressListener.class));
}
private Authentication authentication() {
Authentication authentication = mock(Authentication.class);
when(authentication.getName()).thenReturn("alice");
return authentication;
}
private MockMultipartFile uploadFile(String filename) {
return new MockMultipartFile("file", filename, "text/plain", "hello".getBytes(StandardCharsets.UTF_8));
}
private SftpService.SftpSession connectedSession(boolean connected) {
Session session = mock(Session.class);
ChannelSftp channel = mock(ChannelSftp.class);
if (connected) {
when(channel.isConnected()).thenReturn(true);
lenient().when(channel.isConnected()).thenReturn(true);
}
return new SftpService.SftpSession(session, channel);
}
@@ -1,5 +1,9 @@
package com.sshmanager.service;
import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.Session;
import com.jcraft.jsch.SftpATTRS;
import com.jcraft.jsch.SftpException;
import com.sshmanager.entity.Connection;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -8,6 +12,8 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
class SftpServiceTest {
@@ -48,4 +54,29 @@ class SftpServiceTest {
executorService.shutdown();
assertTrue(executorService.isTerminated() || executorService.isShutdown());
}
@Test
void statIfExistsReturnsNullWhenRemotePathIsMissing() throws Exception {
Session session = mock(Session.class);
ChannelSftp channel = mock(ChannelSftp.class);
when(channel.stat("/missing.txt")).thenThrow(new SftpException(ChannelSftp.SSH_FX_NO_SUCH_FILE, "missing"));
SftpService.PathInfo result = sftpService.statIfExists(new SftpService.SftpSession(session, channel), "/missing.txt");
assertNull(result);
}
@Test
void statIfExistsReturnsDirectoryFlagForExistingPath() throws Exception {
Session session = mock(Session.class);
ChannelSftp channel = mock(ChannelSftp.class);
SftpATTRS attrs = mock(SftpATTRS.class);
when(channel.stat("/existing")).thenReturn(attrs);
when(attrs.isDir()).thenReturn(true);
SftpService.PathInfo result = sftpService.statIfExists(new SftpService.SftpSession(session, channel), "/existing");
assertNotNull(result);
assertTrue(result.directory);
}
}