feat: refine sftp pane upload workflow
This commit is contained in:
@@ -40,6 +40,7 @@ import java.util.stream.Stream;
|
||||
public class SftpController {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(SftpController.class);
|
||||
private static final String UPLOAD_CONFLICT_CODE = "SFTP_UPLOAD_CONFLICT";
|
||||
|
||||
private final ConnectionService connectionService;
|
||||
private final UserRepository userRepository;
|
||||
@@ -292,11 +293,32 @@ public class SftpController {
|
||||
public ResponseEntity<Map<String, Object>> upload(
|
||||
@RequestParam Long connectionId,
|
||||
@RequestParam String path,
|
||||
@RequestParam(defaultValue = "false") boolean overwrite,
|
||||
@RequestParam("file") MultipartFile file,
|
||||
Authentication authentication) {
|
||||
java.io.File tempFile = null;
|
||||
try {
|
||||
Long userId = getCurrentUserId(authentication);
|
||||
String key = sessionKey(userId, connectionId);
|
||||
String filename = file.getOriginalFilename() != null ? file.getOriginalFilename() : "upload.bin";
|
||||
String remotePath = resolveUploadPath(path, filename);
|
||||
|
||||
UploadConflictInfo initialConflict = withSessionLock(key, () -> {
|
||||
try {
|
||||
SftpService.SftpSession session = getOrCreateSession(connectionId, userId);
|
||||
return detectUploadConflict(session, remotePath, filename, overwrite);
|
||||
} catch (Exception e) {
|
||||
SftpService.SftpSession existing = sessions.remove(key);
|
||||
if (existing != null) {
|
||||
existing.disconnect();
|
||||
}
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
if (initialConflict != null) {
|
||||
return buildUploadConflictResponse(initialConflict);
|
||||
}
|
||||
|
||||
String taskId = UUID.randomUUID().toString();
|
||||
String taskKey = uploadTaskKey(userId, taskId);
|
||||
|
||||
@@ -305,25 +327,25 @@ public class SftpController {
|
||||
if (!uploadTempDir.exists() && !uploadTempDir.mkdirs()) {
|
||||
throw new IOException("Failed to create upload temp directory: " + uploadTempDir.getAbsolutePath());
|
||||
}
|
||||
tempFile = new java.io.File(uploadTempDir, taskId + "_" + file.getOriginalFilename());
|
||||
tempFile = new java.io.File(uploadTempDir, taskId + "_" + filename);
|
||||
file.transferTo(tempFile);
|
||||
final java.io.File savedFile = tempFile;
|
||||
|
||||
UploadTaskStatus status = new UploadTaskStatus(taskId, userId, connectionId,
|
||||
path, file.getOriginalFilename(), file.getSize());
|
||||
path, filename, file.getSize());
|
||||
status.setController(this);
|
||||
uploadTasks.put(taskKey, status);
|
||||
|
||||
Future<?> future = transferTaskExecutor.submit(() -> {
|
||||
status.setStatus("running");
|
||||
String key = sessionKey(userId, connectionId);
|
||||
try {
|
||||
withSessionLock(key, () -> {
|
||||
try {
|
||||
SftpService.SftpSession session = getOrCreateSession(connectionId, userId);
|
||||
String remotePath = (path == null || path.isEmpty() || path.equals("/"))
|
||||
? "/" + savedFile.getName().substring(savedFile.getName().indexOf("_") + 1)
|
||||
: (path.endsWith("/") ? path + savedFile.getName().substring(savedFile.getName().indexOf("_") + 1) : path + "/" + savedFile.getName().substring(savedFile.getName().indexOf("_") + 1));
|
||||
UploadConflictInfo conflict = detectUploadConflict(session, remotePath, filename, overwrite);
|
||||
if (conflict != null) {
|
||||
throw new IllegalStateException(conflict.message);
|
||||
}
|
||||
|
||||
AtomicLong transferred = new AtomicLong(0);
|
||||
try (java.io.InputStream in = new java.io.FileInputStream(savedFile)) {
|
||||
@@ -383,6 +405,42 @@ public class SftpController {
|
||||
}
|
||||
}
|
||||
|
||||
private String resolveUploadPath(String path, String filename) {
|
||||
if (path == null || path.isEmpty() || "/".equals(path)) {
|
||||
return "/" + filename;
|
||||
}
|
||||
return path.endsWith("/") ? path + filename : path + "/" + filename;
|
||||
}
|
||||
|
||||
private UploadConflictInfo detectUploadConflict(SftpService.SftpSession session,
|
||||
String remotePath,
|
||||
String filename,
|
||||
boolean overwrite) throws Exception {
|
||||
SftpService.PathInfo existing = sftpService.statIfExists(session, remotePath);
|
||||
if (existing == null) {
|
||||
return null;
|
||||
}
|
||||
if (overwrite && !existing.directory) {
|
||||
return null;
|
||||
}
|
||||
|
||||
boolean canOverwrite = !existing.directory;
|
||||
String message = existing.directory
|
||||
? "目标目录中已存在同名文件夹,无法覆盖。"
|
||||
: "目标目录中已存在同名文件。";
|
||||
return new UploadConflictInfo(filename, existing.directory ? "dir" : "file", canOverwrite, message);
|
||||
}
|
||||
|
||||
private ResponseEntity<Map<String, Object>> buildUploadConflictResponse(UploadConflictInfo conflict) {
|
||||
Map<String, Object> response = new HashMap<>();
|
||||
response.put("code", UPLOAD_CONFLICT_CODE);
|
||||
response.put("fileName", conflict.fileName);
|
||||
response.put("conflictType", conflict.conflictType);
|
||||
response.put("canOverwrite", conflict.canOverwrite);
|
||||
response.put("message", conflict.message);
|
||||
return ResponseEntity.status(409).body(response);
|
||||
}
|
||||
|
||||
@DeleteMapping("/delete")
|
||||
public ResponseEntity<Map<String, String>> delete(
|
||||
@RequestParam Long connectionId,
|
||||
@@ -739,6 +797,20 @@ public class SftpController {
|
||||
|
||||
private final SftpSessionExpiryCleanup cleanupTask = new SftpSessionExpiryCleanup();
|
||||
|
||||
private static class UploadConflictInfo {
|
||||
private final String fileName;
|
||||
private final String conflictType;
|
||||
private final boolean canOverwrite;
|
||||
private final String message;
|
||||
|
||||
private UploadConflictInfo(String fileName, String conflictType, boolean canOverwrite, String message) {
|
||||
this.fileName = fileName;
|
||||
this.conflictType = conflictType;
|
||||
this.canOverwrite = canOverwrite;
|
||||
this.message = message;
|
||||
}
|
||||
}
|
||||
|
||||
public static class SftpSessionExpiryCleanup {
|
||||
private final Map<String, Long> lastAccessTime = new ConcurrentHashMap<>();
|
||||
|
||||
|
||||
@@ -91,10 +91,10 @@ public class SftpService {
|
||||
}
|
||||
|
||||
public static class FileInfo {
|
||||
public String name;
|
||||
public boolean directory;
|
||||
public long size;
|
||||
public long mtime;
|
||||
public String name;
|
||||
public boolean directory;
|
||||
public long size;
|
||||
public long mtime;
|
||||
|
||||
public FileInfo(String name, boolean directory, long size, long mtime) {
|
||||
this.name = name;
|
||||
@@ -104,16 +104,24 @@ public class SftpService {
|
||||
}
|
||||
}
|
||||
|
||||
public static class PathInfo {
|
||||
public final boolean directory;
|
||||
|
||||
public PathInfo(boolean directory) {
|
||||
this.directory = directory;
|
||||
}
|
||||
}
|
||||
|
||||
public interface TransferProgressListener {
|
||||
void onStart(long totalBytes);
|
||||
|
||||
void onProgress(long transferredBytes, long totalBytes);
|
||||
}
|
||||
|
||||
public List<FileInfo> listFiles(SftpSession sftpSession, String path) throws Exception {
|
||||
String listPath = (path == null || path.trim().isEmpty()) ? "." : path.trim();
|
||||
try {
|
||||
Vector<?> entries = sftpSession.getChannel().ls(listPath);
|
||||
public List<FileInfo> listFiles(SftpSession sftpSession, String path) throws Exception {
|
||||
String listPath = (path == null || path.trim().isEmpty()) ? "." : path.trim();
|
||||
try {
|
||||
Vector<?> entries = sftpSession.getChannel().ls(listPath);
|
||||
List<FileInfo> result = new ArrayList<>();
|
||||
for (Object obj : entries) {
|
||||
ChannelSftp.LsEntry entry = (ChannelSftp.LsEntry) obj;
|
||||
@@ -297,4 +305,16 @@ public class SftpService {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public PathInfo statIfExists(SftpSession sftpSession, String path) throws Exception {
|
||||
try {
|
||||
SftpATTRS attrs = sftpSession.getChannel().stat(path);
|
||||
return new PathInfo(attrs.isDir());
|
||||
} catch (SftpException e) {
|
||||
if (e.id == ChannelSftp.SSH_FX_NO_SUCH_FILE) {
|
||||
return null;
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,15 +10,24 @@ import com.sshmanager.service.SftpService;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
import org.springframework.mock.web.MockMultipartFile;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
@@ -28,7 +37,9 @@ import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.lenient;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.timeout;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
@@ -45,15 +56,26 @@ class SftpControllerTest {
|
||||
@Mock
|
||||
private SftpService sftpService;
|
||||
|
||||
@InjectMocks
|
||||
private SftpController sftpController;
|
||||
|
||||
@TempDir
|
||||
Path tempDir;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
User user = new User();
|
||||
user.setId(1L);
|
||||
user.setUsername("alice");
|
||||
when(userRepository.findByUsername("alice")).thenReturn(Optional.of(user));
|
||||
sftpController = new SftpController(connectionService, userRepository, sftpService, tempDir.toString());
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
ExecutorService executor = (ExecutorService) ReflectionTestUtils.getField(sftpController, "transferTaskExecutor");
|
||||
if (executor != null) {
|
||||
executor.shutdownNow();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -133,17 +155,93 @@ class SftpControllerTest {
|
||||
assertTrue(response.getBody().get("error").contains("boom"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void uploadReturnsConflictWhenTargetFileExistsAndOverwriteDisabled() throws Exception {
|
||||
when(connectionService.getConnectionForSsh(anyLong(), eq(1L))).thenReturn(new Connection());
|
||||
SftpService.SftpSession session = connectedSession(true);
|
||||
when(sftpService.connect(any(Connection.class), any(), any(), any())).thenReturn(session);
|
||||
when(sftpService.statIfExists(session, "/uploads/demo.txt")).thenReturn(new SftpService.PathInfo(false));
|
||||
|
||||
ResponseEntity<Map<String, Object>> response = sftpController.upload(
|
||||
7L,
|
||||
"/uploads",
|
||||
false,
|
||||
uploadFile("demo.txt"),
|
||||
authentication()
|
||||
);
|
||||
|
||||
assertEquals(HttpStatus.CONFLICT, response.getStatusCode());
|
||||
assertEquals("SFTP_UPLOAD_CONFLICT", response.getBody().get("code"));
|
||||
assertEquals("demo.txt", response.getBody().get("fileName"));
|
||||
assertEquals("file", response.getBody().get("conflictType"));
|
||||
assertEquals(Boolean.TRUE, response.getBody().get("canOverwrite"));
|
||||
verify(sftpService, never()).upload(any(), anyString(), any(InputStream.class), any(SftpService.TransferProgressListener.class));
|
||||
try (Stream<Path> files = Files.list(tempDir)) {
|
||||
assertEquals(0L, files.count());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void uploadAllowsOverwriteWhenTargetFileExistsAndOverwriteEnabled() throws Exception {
|
||||
when(connectionService.getConnectionForSsh(anyLong(), eq(1L))).thenReturn(new Connection());
|
||||
SftpService.SftpSession session = connectedSession(true);
|
||||
when(sftpService.connect(any(Connection.class), any(), any(), any())).thenReturn(session);
|
||||
when(sftpService.statIfExists(session, "/uploads/demo.txt")).thenReturn(new SftpService.PathInfo(false));
|
||||
|
||||
ResponseEntity<Map<String, Object>> response = sftpController.upload(
|
||||
7L,
|
||||
"/uploads",
|
||||
true,
|
||||
uploadFile("demo.txt"),
|
||||
authentication()
|
||||
);
|
||||
|
||||
assertEquals(HttpStatus.OK, response.getStatusCode());
|
||||
assertTrue(response.getBody().containsKey("taskId"));
|
||||
verify(sftpService, timeout(1000)).upload(
|
||||
eq(session),
|
||||
eq("/uploads/demo.txt"),
|
||||
any(InputStream.class),
|
||||
any(SftpService.TransferProgressListener.class)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void uploadReturnsConflictWhenTargetDirectoryExistsEvenWithOverwriteEnabled() throws Exception {
|
||||
when(connectionService.getConnectionForSsh(anyLong(), eq(1L))).thenReturn(new Connection());
|
||||
SftpService.SftpSession session = connectedSession(true);
|
||||
when(sftpService.connect(any(Connection.class), any(), any(), any())).thenReturn(session);
|
||||
when(sftpService.statIfExists(session, "/uploads/demo.txt")).thenReturn(new SftpService.PathInfo(true));
|
||||
|
||||
ResponseEntity<Map<String, Object>> response = sftpController.upload(
|
||||
7L,
|
||||
"/uploads",
|
||||
true,
|
||||
uploadFile("demo.txt"),
|
||||
authentication()
|
||||
);
|
||||
|
||||
assertEquals(HttpStatus.CONFLICT, response.getStatusCode());
|
||||
assertEquals("dir", response.getBody().get("conflictType"));
|
||||
assertEquals(Boolean.FALSE, response.getBody().get("canOverwrite"));
|
||||
verify(sftpService, never()).upload(any(), anyString(), any(InputStream.class), any(SftpService.TransferProgressListener.class));
|
||||
}
|
||||
|
||||
private Authentication authentication() {
|
||||
Authentication authentication = mock(Authentication.class);
|
||||
when(authentication.getName()).thenReturn("alice");
|
||||
return authentication;
|
||||
}
|
||||
|
||||
private MockMultipartFile uploadFile(String filename) {
|
||||
return new MockMultipartFile("file", filename, "text/plain", "hello".getBytes(StandardCharsets.UTF_8));
|
||||
}
|
||||
|
||||
private SftpService.SftpSession connectedSession(boolean connected) {
|
||||
Session session = mock(Session.class);
|
||||
ChannelSftp channel = mock(ChannelSftp.class);
|
||||
if (connected) {
|
||||
when(channel.isConnected()).thenReturn(true);
|
||||
lenient().when(channel.isConnected()).thenReturn(true);
|
||||
}
|
||||
return new SftpService.SftpSession(session, channel);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
package com.sshmanager.service;
|
||||
|
||||
import com.jcraft.jsch.ChannelSftp;
|
||||
import com.jcraft.jsch.Session;
|
||||
import com.jcraft.jsch.SftpATTRS;
|
||||
import com.jcraft.jsch.SftpException;
|
||||
import com.sshmanager.entity.Connection;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -8,6 +12,8 @@ import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class SftpServiceTest {
|
||||
|
||||
@@ -48,4 +54,29 @@ class SftpServiceTest {
|
||||
executorService.shutdown();
|
||||
assertTrue(executorService.isTerminated() || executorService.isShutdown());
|
||||
}
|
||||
|
||||
@Test
|
||||
void statIfExistsReturnsNullWhenRemotePathIsMissing() throws Exception {
|
||||
Session session = mock(Session.class);
|
||||
ChannelSftp channel = mock(ChannelSftp.class);
|
||||
when(channel.stat("/missing.txt")).thenThrow(new SftpException(ChannelSftp.SSH_FX_NO_SUCH_FILE, "missing"));
|
||||
|
||||
SftpService.PathInfo result = sftpService.statIfExists(new SftpService.SftpSession(session, channel), "/missing.txt");
|
||||
|
||||
assertNull(result);
|
||||
}
|
||||
|
||||
@Test
|
||||
void statIfExistsReturnsDirectoryFlagForExistingPath() throws Exception {
|
||||
Session session = mock(Session.class);
|
||||
ChannelSftp channel = mock(ChannelSftp.class);
|
||||
SftpATTRS attrs = mock(SftpATTRS.class);
|
||||
when(channel.stat("/existing")).thenReturn(attrs);
|
||||
when(attrs.isDir()).thenReturn(true);
|
||||
|
||||
SftpService.PathInfo result = sftpService.statIfExists(new SftpService.SftpSession(session, channel), "/existing");
|
||||
|
||||
assertNotNull(result);
|
||||
assertTrue(result.directory);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user