Improve SFTP reliability and credential hygiene with regression tests

This commit is contained in:
2026-03-09 00:02:22 +08:00
parent a61a88f36b
commit a10906d711
7 changed files with 571 additions and 227 deletions

View File

@@ -9,12 +9,13 @@ import com.sshmanager.service.ConnectionService;
import com.sshmanager.service.SftpService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.Authentication;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.Authentication;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
import java.util.HashMap;
import java.util.List;
@@ -57,12 +58,27 @@ public class SftpController {
return userId + ":" + connectionId;
}
private <T> T withSessionLock(String key, Supplier<T> action) {
Object lock = sessionLocks.computeIfAbsent(key, k -> new Object());
synchronized (lock) {
return action.get();
}
}
private <T> T withSessionLock(String key, Supplier<T> action) {
Object lock = sessionLocks.computeIfAbsent(key, k -> new Object());
synchronized (lock) {
return action.get();
}
}
private <T> T withTwoSessionLocks(String keyA, String keyB, Supplier<T> action) {
if (keyA.equals(keyB)) {
return withSessionLock(keyA, action);
}
String first = keyA.compareTo(keyB) < 0 ? keyA : keyB;
String second = keyA.compareTo(keyB) < 0 ? keyB : keyA;
Object firstLock = sessionLocks.computeIfAbsent(first, k -> new Object());
Object secondLock = sessionLocks.computeIfAbsent(second, k -> new Object());
synchronized (firstLock) {
synchronized (secondLock) {
return action.get();
}
}
}
private SftpService.SftpSession getOrCreateSession(Long connectionId, Long userId) throws Exception {
String key = sessionKey(userId, connectionId);
@@ -160,35 +176,38 @@ public class SftpController {
}
}
@GetMapping("/download")
public ResponseEntity<byte[]> download(
@RequestParam Long connectionId,
@RequestParam String path,
Authentication authentication) {
try {
Long userId = getCurrentUserId(authentication);
String key = sessionKey(userId, connectionId);
return withSessionLock(key, () -> {
try {
SftpService.SftpSession session = getOrCreateSession(connectionId, userId);
byte[] data = sftpService.download(session, path);
String filename = path.contains("/") ? path.substring(path.lastIndexOf('/') + 1) : path;
return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + filename + "\"")
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.body(data);
} catch (Exception e) {
SftpService.SftpSession existing = sessions.remove(key);
if (existing != null) {
existing.disconnect();
}
throw new RuntimeException(e);
}
});
} catch (Exception e) {
return ResponseEntity.status(500).build();
}
}
@GetMapping("/download")
public ResponseEntity<StreamingResponseBody> download(
@RequestParam Long connectionId,
@RequestParam String path,
Authentication authentication) {
try {
Long userId = getCurrentUserId(authentication);
String key = sessionKey(userId, connectionId);
String filename = path.contains("/") ? path.substring(path.lastIndexOf('/') + 1) : path;
StreamingResponseBody stream = outputStream -> withSessionLock(key, () -> {
try {
SftpService.SftpSession session = getOrCreateSession(connectionId, userId);
sftpService.download(session, path, outputStream);
outputStream.flush();
return null;
} catch (Exception e) {
SftpService.SftpSession existing = sessions.remove(key);
if (existing != null) {
existing.disconnect();
}
throw new RuntimeException(e);
}
});
return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + filename + "\"")
.contentType(MediaType.APPLICATION_OCTET_STREAM)
.body(stream);
} catch (Exception e) {
return ResponseEntity.status(500).build();
}
}
@PostMapping("/upload")
public ResponseEntity<Map<String, String>> upload(
@@ -202,13 +221,15 @@ public class SftpController {
return withSessionLock(key, () -> {
try {
SftpService.SftpSession session = getOrCreateSession(connectionId, userId);
String remotePath = (path == null || path.isEmpty() || path.equals("/"))
? "/" + file.getOriginalFilename()
: (path.endsWith("/") ? path + file.getOriginalFilename() : path + "/" + file.getOriginalFilename());
sftpService.upload(session, remotePath, file.getBytes());
Map<String, String> result = new HashMap<>();
result.put("message", "Uploaded");
return ResponseEntity.ok(result);
String remotePath = (path == null || path.isEmpty() || path.equals("/"))
? "/" + file.getOriginalFilename()
: (path.endsWith("/") ? path + file.getOriginalFilename() : path + "/" + file.getOriginalFilename());
try (java.io.InputStream in = file.getInputStream()) {
sftpService.upload(session, remotePath, in);
}
Map<String, String> result = new HashMap<>();
result.put("message", "Uploaded");
return ResponseEntity.ok(result);
} catch (Exception e) {
SftpService.SftpSession existing = sessions.remove(key);
if (existing != null) {
@@ -317,36 +338,55 @@ public class SftpController {
}
@PostMapping("/transfer-remote")
public ResponseEntity<Map<String, String>> transferRemote(
@RequestParam Long sourceConnectionId,
@RequestParam String sourcePath,
@RequestParam Long targetConnectionId,
@RequestParam String targetPath,
Authentication authentication) {
try {
Long userId = getCurrentUserId(authentication);
public ResponseEntity<Map<String, String>> transferRemote(
@RequestParam Long sourceConnectionId,
@RequestParam String sourcePath,
@RequestParam Long targetConnectionId,
@RequestParam String targetPath,
Authentication authentication) {
try {
Long userId = getCurrentUserId(authentication);
if (sourcePath == null || sourcePath.trim().isEmpty()) {
Map<String, String> err = new HashMap<>();
err.put("error", "sourcePath is required");
return ResponseEntity.badRequest().body(err);
}
if (targetPath == null || targetPath.trim().isEmpty()) {
Map<String, String> err = new HashMap<>();
err.put("error", "targetPath is required");
return ResponseEntity.badRequest().body(err);
}
SftpService.SftpSession sourceSession = getOrCreateSession(sourceConnectionId, userId);
SftpService.SftpSession targetSession = getOrCreateSession(targetConnectionId, userId);
if (sourceConnectionId.equals(targetConnectionId)) {
sftpService.rename(sourceSession, sourcePath.trim(), targetPath.trim());
} else {
sftpService.transferRemote(sourceSession, sourcePath.trim(), targetSession, targetPath.trim());
}
Map<String, String> result = new HashMap<>();
result.put("message", "Transferred");
return ResponseEntity.ok(result);
} catch (Exception e) {
Map<String, String> error = new HashMap<>();
if (targetPath == null || targetPath.trim().isEmpty()) {
Map<String, String> err = new HashMap<>();
err.put("error", "targetPath is required");
return ResponseEntity.badRequest().body(err);
}
String sourceKey = sessionKey(userId, sourceConnectionId);
String targetKey = sessionKey(userId, targetConnectionId);
withTwoSessionLocks(sourceKey, targetKey, () -> {
try {
SftpService.SftpSession sourceSession = getOrCreateSession(sourceConnectionId, userId);
SftpService.SftpSession targetSession = getOrCreateSession(targetConnectionId, userId);
if (sourceConnectionId.equals(targetConnectionId)) {
sftpService.rename(sourceSession, sourcePath.trim(), targetPath.trim());
} else {
sftpService.transferRemote(sourceSession, sourcePath.trim(), targetSession, targetPath.trim());
}
return null;
} catch (Exception e) {
SftpService.SftpSession source = sessions.remove(sourceKey);
if (source != null) {
source.disconnect();
}
if (!sourceKey.equals(targetKey)) {
SftpService.SftpSession target = sessions.remove(targetKey);
if (target != null) {
target.disconnect();
}
}
throw new RuntimeException(e);
}
});
Map<String, String> result = new HashMap<>();
result.put("message", "Transferred");
return ResponseEntity.ok(result);
} catch (Exception e) {
Map<String, String> error = new HashMap<>();
error.put("error", e.getMessage() != null ? e.getMessage() : "Transfer failed");
return ResponseEntity.status(500).body(error);
}

View File

@@ -45,17 +45,18 @@ public class ConnectionService {
conn.setName(request.getName());
conn.setHost(request.getHost());
conn.setPort(request.getPort() != null ? request.getPort() : 22);
conn.setUsername(request.getUsername());
conn.setAuthType(request.getAuthType() != null ? request.getAuthType() : Connection.AuthType.PASSWORD);
if (conn.getAuthType() == Connection.AuthType.PASSWORD && request.getPassword() != null) {
conn.setEncryptedPassword(encryptionService.encrypt(request.getPassword()));
} else if (conn.getAuthType() == Connection.AuthType.PRIVATE_KEY && request.getPrivateKey() != null) {
conn.setEncryptedPrivateKey(encryptionService.encrypt(request.getPrivateKey()));
if (request.getPassphrase() != null && !request.getPassphrase().isEmpty()) {
conn.setPassphrase(encryptionService.encrypt(request.getPassphrase()));
}
}
conn.setUsername(request.getUsername());
conn.setAuthType(request.getAuthType() != null ? request.getAuthType() : Connection.AuthType.PASSWORD);
if (conn.getAuthType() == Connection.AuthType.PASSWORD) {
conn.setEncryptedPassword(encryptionService.encrypt(request.getPassword()));
conn.setEncryptedPrivateKey(null);
conn.setPassphrase(null);
} else {
conn.setEncryptedPassword(null);
conn.setEncryptedPrivateKey(encryptionService.encrypt(request.getPrivateKey()));
conn.setPassphrase(encryptionService.encrypt(request.getPassphrase()));
}
conn = connectionRepository.save(conn);
return ConnectionDto.fromEntity(conn);
@@ -69,22 +70,27 @@ public class ConnectionService {
throw new RuntimeException("Access denied");
}
if (request.getName() != null) conn.setName(request.getName());
if (request.getHost() != null) conn.setHost(request.getHost());
if (request.getPort() != null) conn.setPort(request.getPort());
if (request.getUsername() != null) conn.setUsername(request.getUsername());
if (request.getAuthType() != null) conn.setAuthType(request.getAuthType());
if (request.getPassword() != null) {
conn.setEncryptedPassword(encryptionService.encrypt(request.getPassword()));
}
if (request.getPrivateKey() != null) {
conn.setEncryptedPrivateKey(encryptionService.encrypt(request.getPrivateKey()));
}
if (request.getPassphrase() != null) {
conn.setPassphrase(request.getPassphrase().isEmpty() ? null :
encryptionService.encrypt(request.getPassphrase()));
}
if (request.getName() != null) conn.setName(request.getName());
if (request.getHost() != null) conn.setHost(request.getHost());
if (request.getPort() != null) conn.setPort(request.getPort());
if (request.getUsername() != null) conn.setUsername(request.getUsername());
if (request.getAuthType() != null) conn.setAuthType(request.getAuthType());
if (conn.getAuthType() == Connection.AuthType.PASSWORD) {
if (request.getPassword() != null) {
conn.setEncryptedPassword(encryptionService.encrypt(request.getPassword()));
}
conn.setEncryptedPrivateKey(null);
conn.setPassphrase(null);
} else {
if (request.getPrivateKey() != null) {
conn.setEncryptedPrivateKey(encryptionService.encrypt(request.getPrivateKey()));
}
if (request.getPassphrase() != null) {
conn.setPassphrase(encryptionService.encrypt(request.getPassphrase()));
}
conn.setEncryptedPassword(null);
}
conn.setUpdatedAt(Instant.now());
conn = connectionRepository.save(conn);

View File

@@ -4,15 +4,14 @@ import com.jcraft.jsch.ChannelSftp;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session;
import com.jcraft.jsch.SftpException;
import com.sshmanager.entity.Connection;
import org.springframework.stereotype.Service;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.nio.charset.StandardCharsets;
import com.sshmanager.entity.Connection;
import org.springframework.stereotype.Service;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;
@@ -149,15 +148,13 @@ public class SftpService {
}
}
public byte[] download(SftpSession sftpSession, String remotePath) throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
sftpSession.getChannel().get(remotePath, out);
return out.toByteArray();
}
public void upload(SftpSession sftpSession, String remotePath, byte[] data) throws Exception {
sftpSession.getChannel().put(new ByteArrayInputStream(data), remotePath);
}
public void download(SftpSession sftpSession, String remotePath, OutputStream out) throws Exception {
sftpSession.getChannel().get(remotePath, out);
}
public void upload(SftpSession sftpSession, String remotePath, InputStream in) throws Exception {
sftpSession.getChannel().put(in, remotePath);
}
public void delete(SftpSession sftpSession, String remotePath, boolean isDir) throws Exception {
if (isDir) {