Improve SFTP reliability and credential hygiene with regression tests
This commit is contained in:
@@ -9,12 +9,13 @@ import com.sshmanager.service.ConnectionService;
|
||||
import com.sshmanager.service.SftpService;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
@@ -57,12 +58,27 @@ public class SftpController {
|
||||
return userId + ":" + connectionId;
|
||||
}
|
||||
|
||||
private <T> T withSessionLock(String key, Supplier<T> action) {
|
||||
Object lock = sessionLocks.computeIfAbsent(key, k -> new Object());
|
||||
synchronized (lock) {
|
||||
return action.get();
|
||||
}
|
||||
}
|
||||
private <T> T withSessionLock(String key, Supplier<T> action) {
|
||||
Object lock = sessionLocks.computeIfAbsent(key, k -> new Object());
|
||||
synchronized (lock) {
|
||||
return action.get();
|
||||
}
|
||||
}
|
||||
|
||||
private <T> T withTwoSessionLocks(String keyA, String keyB, Supplier<T> action) {
|
||||
if (keyA.equals(keyB)) {
|
||||
return withSessionLock(keyA, action);
|
||||
}
|
||||
String first = keyA.compareTo(keyB) < 0 ? keyA : keyB;
|
||||
String second = keyA.compareTo(keyB) < 0 ? keyB : keyA;
|
||||
Object firstLock = sessionLocks.computeIfAbsent(first, k -> new Object());
|
||||
Object secondLock = sessionLocks.computeIfAbsent(second, k -> new Object());
|
||||
synchronized (firstLock) {
|
||||
synchronized (secondLock) {
|
||||
return action.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private SftpService.SftpSession getOrCreateSession(Long connectionId, Long userId) throws Exception {
|
||||
String key = sessionKey(userId, connectionId);
|
||||
@@ -160,35 +176,38 @@ public class SftpController {
|
||||
}
|
||||
}
|
||||
|
||||
@GetMapping("/download")
|
||||
public ResponseEntity<byte[]> download(
|
||||
@RequestParam Long connectionId,
|
||||
@RequestParam String path,
|
||||
Authentication authentication) {
|
||||
try {
|
||||
Long userId = getCurrentUserId(authentication);
|
||||
String key = sessionKey(userId, connectionId);
|
||||
return withSessionLock(key, () -> {
|
||||
try {
|
||||
SftpService.SftpSession session = getOrCreateSession(connectionId, userId);
|
||||
byte[] data = sftpService.download(session, path);
|
||||
String filename = path.contains("/") ? path.substring(path.lastIndexOf('/') + 1) : path;
|
||||
return ResponseEntity.ok()
|
||||
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + filename + "\"")
|
||||
.contentType(MediaType.APPLICATION_OCTET_STREAM)
|
||||
.body(data);
|
||||
} catch (Exception e) {
|
||||
SftpService.SftpSession existing = sessions.remove(key);
|
||||
if (existing != null) {
|
||||
existing.disconnect();
|
||||
}
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
} catch (Exception e) {
|
||||
return ResponseEntity.status(500).build();
|
||||
}
|
||||
}
|
||||
@GetMapping("/download")
|
||||
public ResponseEntity<StreamingResponseBody> download(
|
||||
@RequestParam Long connectionId,
|
||||
@RequestParam String path,
|
||||
Authentication authentication) {
|
||||
try {
|
||||
Long userId = getCurrentUserId(authentication);
|
||||
String key = sessionKey(userId, connectionId);
|
||||
String filename = path.contains("/") ? path.substring(path.lastIndexOf('/') + 1) : path;
|
||||
StreamingResponseBody stream = outputStream -> withSessionLock(key, () -> {
|
||||
try {
|
||||
SftpService.SftpSession session = getOrCreateSession(connectionId, userId);
|
||||
sftpService.download(session, path, outputStream);
|
||||
outputStream.flush();
|
||||
return null;
|
||||
} catch (Exception e) {
|
||||
SftpService.SftpSession existing = sessions.remove(key);
|
||||
if (existing != null) {
|
||||
existing.disconnect();
|
||||
}
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
|
||||
return ResponseEntity.ok()
|
||||
.header(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + filename + "\"")
|
||||
.contentType(MediaType.APPLICATION_OCTET_STREAM)
|
||||
.body(stream);
|
||||
} catch (Exception e) {
|
||||
return ResponseEntity.status(500).build();
|
||||
}
|
||||
}
|
||||
|
||||
@PostMapping("/upload")
|
||||
public ResponseEntity<Map<String, String>> upload(
|
||||
@@ -202,13 +221,15 @@ public class SftpController {
|
||||
return withSessionLock(key, () -> {
|
||||
try {
|
||||
SftpService.SftpSession session = getOrCreateSession(connectionId, userId);
|
||||
String remotePath = (path == null || path.isEmpty() || path.equals("/"))
|
||||
? "/" + file.getOriginalFilename()
|
||||
: (path.endsWith("/") ? path + file.getOriginalFilename() : path + "/" + file.getOriginalFilename());
|
||||
sftpService.upload(session, remotePath, file.getBytes());
|
||||
Map<String, String> result = new HashMap<>();
|
||||
result.put("message", "Uploaded");
|
||||
return ResponseEntity.ok(result);
|
||||
String remotePath = (path == null || path.isEmpty() || path.equals("/"))
|
||||
? "/" + file.getOriginalFilename()
|
||||
: (path.endsWith("/") ? path + file.getOriginalFilename() : path + "/" + file.getOriginalFilename());
|
||||
try (java.io.InputStream in = file.getInputStream()) {
|
||||
sftpService.upload(session, remotePath, in);
|
||||
}
|
||||
Map<String, String> result = new HashMap<>();
|
||||
result.put("message", "Uploaded");
|
||||
return ResponseEntity.ok(result);
|
||||
} catch (Exception e) {
|
||||
SftpService.SftpSession existing = sessions.remove(key);
|
||||
if (existing != null) {
|
||||
@@ -317,36 +338,55 @@ public class SftpController {
|
||||
}
|
||||
|
||||
@PostMapping("/transfer-remote")
|
||||
public ResponseEntity<Map<String, String>> transferRemote(
|
||||
@RequestParam Long sourceConnectionId,
|
||||
@RequestParam String sourcePath,
|
||||
@RequestParam Long targetConnectionId,
|
||||
@RequestParam String targetPath,
|
||||
Authentication authentication) {
|
||||
try {
|
||||
Long userId = getCurrentUserId(authentication);
|
||||
public ResponseEntity<Map<String, String>> transferRemote(
|
||||
@RequestParam Long sourceConnectionId,
|
||||
@RequestParam String sourcePath,
|
||||
@RequestParam Long targetConnectionId,
|
||||
@RequestParam String targetPath,
|
||||
Authentication authentication) {
|
||||
try {
|
||||
Long userId = getCurrentUserId(authentication);
|
||||
if (sourcePath == null || sourcePath.trim().isEmpty()) {
|
||||
Map<String, String> err = new HashMap<>();
|
||||
err.put("error", "sourcePath is required");
|
||||
return ResponseEntity.badRequest().body(err);
|
||||
}
|
||||
if (targetPath == null || targetPath.trim().isEmpty()) {
|
||||
Map<String, String> err = new HashMap<>();
|
||||
err.put("error", "targetPath is required");
|
||||
return ResponseEntity.badRequest().body(err);
|
||||
}
|
||||
SftpService.SftpSession sourceSession = getOrCreateSession(sourceConnectionId, userId);
|
||||
SftpService.SftpSession targetSession = getOrCreateSession(targetConnectionId, userId);
|
||||
if (sourceConnectionId.equals(targetConnectionId)) {
|
||||
sftpService.rename(sourceSession, sourcePath.trim(), targetPath.trim());
|
||||
} else {
|
||||
sftpService.transferRemote(sourceSession, sourcePath.trim(), targetSession, targetPath.trim());
|
||||
}
|
||||
Map<String, String> result = new HashMap<>();
|
||||
result.put("message", "Transferred");
|
||||
return ResponseEntity.ok(result);
|
||||
} catch (Exception e) {
|
||||
Map<String, String> error = new HashMap<>();
|
||||
if (targetPath == null || targetPath.trim().isEmpty()) {
|
||||
Map<String, String> err = new HashMap<>();
|
||||
err.put("error", "targetPath is required");
|
||||
return ResponseEntity.badRequest().body(err);
|
||||
}
|
||||
String sourceKey = sessionKey(userId, sourceConnectionId);
|
||||
String targetKey = sessionKey(userId, targetConnectionId);
|
||||
withTwoSessionLocks(sourceKey, targetKey, () -> {
|
||||
try {
|
||||
SftpService.SftpSession sourceSession = getOrCreateSession(sourceConnectionId, userId);
|
||||
SftpService.SftpSession targetSession = getOrCreateSession(targetConnectionId, userId);
|
||||
if (sourceConnectionId.equals(targetConnectionId)) {
|
||||
sftpService.rename(sourceSession, sourcePath.trim(), targetPath.trim());
|
||||
} else {
|
||||
sftpService.transferRemote(sourceSession, sourcePath.trim(), targetSession, targetPath.trim());
|
||||
}
|
||||
return null;
|
||||
} catch (Exception e) {
|
||||
SftpService.SftpSession source = sessions.remove(sourceKey);
|
||||
if (source != null) {
|
||||
source.disconnect();
|
||||
}
|
||||
if (!sourceKey.equals(targetKey)) {
|
||||
SftpService.SftpSession target = sessions.remove(targetKey);
|
||||
if (target != null) {
|
||||
target.disconnect();
|
||||
}
|
||||
}
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
Map<String, String> result = new HashMap<>();
|
||||
result.put("message", "Transferred");
|
||||
return ResponseEntity.ok(result);
|
||||
} catch (Exception e) {
|
||||
Map<String, String> error = new HashMap<>();
|
||||
error.put("error", e.getMessage() != null ? e.getMessage() : "Transfer failed");
|
||||
return ResponseEntity.status(500).body(error);
|
||||
}
|
||||
|
||||
@@ -45,17 +45,18 @@ public class ConnectionService {
|
||||
conn.setName(request.getName());
|
||||
conn.setHost(request.getHost());
|
||||
conn.setPort(request.getPort() != null ? request.getPort() : 22);
|
||||
conn.setUsername(request.getUsername());
|
||||
conn.setAuthType(request.getAuthType() != null ? request.getAuthType() : Connection.AuthType.PASSWORD);
|
||||
|
||||
if (conn.getAuthType() == Connection.AuthType.PASSWORD && request.getPassword() != null) {
|
||||
conn.setEncryptedPassword(encryptionService.encrypt(request.getPassword()));
|
||||
} else if (conn.getAuthType() == Connection.AuthType.PRIVATE_KEY && request.getPrivateKey() != null) {
|
||||
conn.setEncryptedPrivateKey(encryptionService.encrypt(request.getPrivateKey()));
|
||||
if (request.getPassphrase() != null && !request.getPassphrase().isEmpty()) {
|
||||
conn.setPassphrase(encryptionService.encrypt(request.getPassphrase()));
|
||||
}
|
||||
}
|
||||
conn.setUsername(request.getUsername());
|
||||
conn.setAuthType(request.getAuthType() != null ? request.getAuthType() : Connection.AuthType.PASSWORD);
|
||||
|
||||
if (conn.getAuthType() == Connection.AuthType.PASSWORD) {
|
||||
conn.setEncryptedPassword(encryptionService.encrypt(request.getPassword()));
|
||||
conn.setEncryptedPrivateKey(null);
|
||||
conn.setPassphrase(null);
|
||||
} else {
|
||||
conn.setEncryptedPassword(null);
|
||||
conn.setEncryptedPrivateKey(encryptionService.encrypt(request.getPrivateKey()));
|
||||
conn.setPassphrase(encryptionService.encrypt(request.getPassphrase()));
|
||||
}
|
||||
|
||||
conn = connectionRepository.save(conn);
|
||||
return ConnectionDto.fromEntity(conn);
|
||||
@@ -69,22 +70,27 @@ public class ConnectionService {
|
||||
throw new RuntimeException("Access denied");
|
||||
}
|
||||
|
||||
if (request.getName() != null) conn.setName(request.getName());
|
||||
if (request.getHost() != null) conn.setHost(request.getHost());
|
||||
if (request.getPort() != null) conn.setPort(request.getPort());
|
||||
if (request.getUsername() != null) conn.setUsername(request.getUsername());
|
||||
if (request.getAuthType() != null) conn.setAuthType(request.getAuthType());
|
||||
|
||||
if (request.getPassword() != null) {
|
||||
conn.setEncryptedPassword(encryptionService.encrypt(request.getPassword()));
|
||||
}
|
||||
if (request.getPrivateKey() != null) {
|
||||
conn.setEncryptedPrivateKey(encryptionService.encrypt(request.getPrivateKey()));
|
||||
}
|
||||
if (request.getPassphrase() != null) {
|
||||
conn.setPassphrase(request.getPassphrase().isEmpty() ? null :
|
||||
encryptionService.encrypt(request.getPassphrase()));
|
||||
}
|
||||
if (request.getName() != null) conn.setName(request.getName());
|
||||
if (request.getHost() != null) conn.setHost(request.getHost());
|
||||
if (request.getPort() != null) conn.setPort(request.getPort());
|
||||
if (request.getUsername() != null) conn.setUsername(request.getUsername());
|
||||
if (request.getAuthType() != null) conn.setAuthType(request.getAuthType());
|
||||
|
||||
if (conn.getAuthType() == Connection.AuthType.PASSWORD) {
|
||||
if (request.getPassword() != null) {
|
||||
conn.setEncryptedPassword(encryptionService.encrypt(request.getPassword()));
|
||||
}
|
||||
conn.setEncryptedPrivateKey(null);
|
||||
conn.setPassphrase(null);
|
||||
} else {
|
||||
if (request.getPrivateKey() != null) {
|
||||
conn.setEncryptedPrivateKey(encryptionService.encrypt(request.getPrivateKey()));
|
||||
}
|
||||
if (request.getPassphrase() != null) {
|
||||
conn.setPassphrase(encryptionService.encrypt(request.getPassphrase()));
|
||||
}
|
||||
conn.setEncryptedPassword(null);
|
||||
}
|
||||
|
||||
conn.setUpdatedAt(Instant.now());
|
||||
conn = connectionRepository.save(conn);
|
||||
|
||||
@@ -4,15 +4,14 @@ import com.jcraft.jsch.ChannelSftp;
|
||||
import com.jcraft.jsch.JSch;
|
||||
import com.jcraft.jsch.Session;
|
||||
import com.jcraft.jsch.SftpException;
|
||||
import com.sshmanager.entity.Connection;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.InputStream;
|
||||
import java.io.PipedInputStream;
|
||||
import java.io.PipedOutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import com.sshmanager.entity.Connection;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.io.PipedInputStream;
|
||||
import java.io.PipedOutputStream;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Vector;
|
||||
@@ -149,15 +148,13 @@ public class SftpService {
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] download(SftpSession sftpSession, String remotePath) throws Exception {
|
||||
ByteArrayOutputStream out = new ByteArrayOutputStream();
|
||||
sftpSession.getChannel().get(remotePath, out);
|
||||
return out.toByteArray();
|
||||
}
|
||||
|
||||
public void upload(SftpSession sftpSession, String remotePath, byte[] data) throws Exception {
|
||||
sftpSession.getChannel().put(new ByteArrayInputStream(data), remotePath);
|
||||
}
|
||||
public void download(SftpSession sftpSession, String remotePath, OutputStream out) throws Exception {
|
||||
sftpSession.getChannel().get(remotePath, out);
|
||||
}
|
||||
|
||||
public void upload(SftpSession sftpSession, String remotePath, InputStream in) throws Exception {
|
||||
sftpSession.getChannel().put(in, remotePath);
|
||||
}
|
||||
|
||||
public void delete(SftpSession sftpSession, String remotePath, boolean isDir) throws Exception {
|
||||
if (isDir) {
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
package com.sshmanager.controller;
|
||||
|
||||
import com.jcraft.jsch.ChannelSftp;
|
||||
import com.jcraft.jsch.Session;
|
||||
import com.sshmanager.entity.Connection;
|
||||
import com.sshmanager.entity.User;
|
||||
import com.sshmanager.repository.UserRepository;
|
||||
import com.sshmanager.service.ConnectionService;
|
||||
import com.sshmanager.service.SftpService;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.security.core.Authentication;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class SftpControllerTest {
|
||||
|
||||
@Mock
|
||||
private ConnectionService connectionService;
|
||||
|
||||
@Mock
|
||||
private UserRepository userRepository;
|
||||
|
||||
@Mock
|
||||
private SftpService sftpService;
|
||||
|
||||
@InjectMocks
|
||||
private SftpController sftpController;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
User user = new User();
|
||||
user.setId(1L);
|
||||
user.setUsername("alice");
|
||||
when(userRepository.findByUsername("alice")).thenReturn(Optional.of(user));
|
||||
}
|
||||
|
||||
@Test
|
||||
void transferRemoteReturnsBadRequestWhenSourcePathMissing() {
|
||||
ResponseEntity<Map<String, String>> response = sftpController.transferRemote(
|
||||
1L,
|
||||
" ",
|
||||
2L,
|
||||
"/tmp/target.txt",
|
||||
authentication()
|
||||
);
|
||||
|
||||
assertEquals(HttpStatus.BAD_REQUEST, response.getStatusCode());
|
||||
assertEquals("sourcePath is required", response.getBody().get("error"));
|
||||
verifyNoInteractions(sftpService);
|
||||
}
|
||||
|
||||
@Test
|
||||
void transferRemoteUsesRenameWhenSourceAndTargetConnectionAreSame() throws Exception {
|
||||
when(connectionService.getConnectionForSsh(anyLong(), eq(1L))).thenReturn(new Connection());
|
||||
SftpService.SftpSession session = connectedSession(true);
|
||||
when(sftpService.connect(any(Connection.class), any(), any(), any())).thenReturn(session);
|
||||
|
||||
ResponseEntity<Map<String, String>> response = sftpController.transferRemote(
|
||||
3L,
|
||||
"/src/file.txt",
|
||||
3L,
|
||||
"/dst/file.txt",
|
||||
authentication()
|
||||
);
|
||||
|
||||
assertEquals(HttpStatus.OK, response.getStatusCode());
|
||||
assertEquals("Transferred", response.getBody().get("message"));
|
||||
verify(sftpService).rename(session, "/src/file.txt", "/dst/file.txt");
|
||||
verify(sftpService, never()).transferRemote(any(), any(), any(), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void transferRemoteUsesCrossSessionTransferWhenConnectionsDiffer() throws Exception {
|
||||
when(connectionService.getConnectionForSsh(anyLong(), eq(1L))).thenReturn(new Connection());
|
||||
SftpService.SftpSession sourceSession = connectedSession(false);
|
||||
SftpService.SftpSession targetSession = connectedSession(false);
|
||||
when(sftpService.connect(any(Connection.class), any(), any(), any()))
|
||||
.thenReturn(sourceSession)
|
||||
.thenReturn(targetSession);
|
||||
|
||||
ResponseEntity<Map<String, String>> response = sftpController.transferRemote(
|
||||
10L,
|
||||
"/src/file.txt",
|
||||
20L,
|
||||
"/dst/file.txt",
|
||||
authentication()
|
||||
);
|
||||
|
||||
assertEquals(HttpStatus.OK, response.getStatusCode());
|
||||
assertEquals("Transferred", response.getBody().get("message"));
|
||||
verify(sftpService).transferRemote(sourceSession, "/src/file.txt", targetSession, "/dst/file.txt");
|
||||
verify(sftpService, never()).rename(any(), any(), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
void transferRemoteReturnsServerErrorWhenTransferFails() throws Exception {
|
||||
when(connectionService.getConnectionForSsh(anyLong(), eq(1L))).thenReturn(new Connection());
|
||||
SftpService.SftpSession session = connectedSession(true);
|
||||
when(sftpService.connect(any(Connection.class), any(), any(), any())).thenReturn(session);
|
||||
doThrow(new RuntimeException("boom")).when(sftpService).rename(any(), any(), any());
|
||||
|
||||
ResponseEntity<Map<String, String>> response = sftpController.transferRemote(
|
||||
3L,
|
||||
"/src/file.txt",
|
||||
3L,
|
||||
"/dst/file.txt",
|
||||
authentication()
|
||||
);
|
||||
|
||||
assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode());
|
||||
assertTrue(response.getBody().get("error").contains("boom"));
|
||||
}
|
||||
|
||||
private Authentication authentication() {
|
||||
Authentication authentication = mock(Authentication.class);
|
||||
when(authentication.getName()).thenReturn("alice");
|
||||
return authentication;
|
||||
}
|
||||
|
||||
private SftpService.SftpSession connectedSession(boolean connected) {
|
||||
Session session = mock(Session.class);
|
||||
ChannelSftp channel = mock(ChannelSftp.class);
|
||||
if (connected) {
|
||||
when(channel.isConnected()).thenReturn(true);
|
||||
}
|
||||
return new SftpService.SftpSession(session, channel);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package com.sshmanager.service;
|
||||
|
||||
import com.sshmanager.dto.ConnectionCreateRequest;
|
||||
import com.sshmanager.dto.ConnectionDto;
|
||||
import com.sshmanager.entity.Connection;
|
||||
import com.sshmanager.repository.ConnectionRepository;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class ConnectionServiceTest {
|
||||
|
||||
@Mock
|
||||
private ConnectionRepository connectionRepository;
|
||||
|
||||
@Mock
|
||||
private EncryptionService encryptionService;
|
||||
|
||||
@InjectMocks
|
||||
private ConnectionService connectionService;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
when(connectionRepository.save(any(Connection.class))).thenAnswer(invocation -> invocation.getArgument(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
void createPasswordConnectionClearsPrivateKeyCredentials() {
|
||||
ConnectionCreateRequest request = new ConnectionCreateRequest();
|
||||
request.setName("prod");
|
||||
request.setHost("127.0.0.1");
|
||||
request.setUsername("root");
|
||||
request.setAuthType(Connection.AuthType.PASSWORD);
|
||||
request.setPassword("secret");
|
||||
request.setPrivateKey("unused-key");
|
||||
request.setPassphrase("unused-passphrase");
|
||||
|
||||
when(encryptionService.encrypt("secret")).thenReturn("enc-secret");
|
||||
|
||||
ConnectionDto result = connectionService.create(request, 1L);
|
||||
|
||||
assertNotNull(result);
|
||||
ArgumentCaptor<Connection> captor = ArgumentCaptor.forClass(Connection.class);
|
||||
verify(connectionRepository).save(captor.capture());
|
||||
Connection saved = captor.getValue();
|
||||
|
||||
assertEquals(Connection.AuthType.PASSWORD, saved.getAuthType());
|
||||
assertEquals("enc-secret", saved.getEncryptedPassword());
|
||||
assertNull(saved.getEncryptedPrivateKey());
|
||||
assertNull(saved.getPassphrase());
|
||||
}
|
||||
|
||||
@Test
|
||||
void createPrivateKeyConnectionClearsPasswordCredential() {
|
||||
ConnectionCreateRequest request = new ConnectionCreateRequest();
|
||||
request.setName("prod");
|
||||
request.setHost("127.0.0.1");
|
||||
request.setUsername("root");
|
||||
request.setAuthType(Connection.AuthType.PRIVATE_KEY);
|
||||
request.setPassword("unused-password");
|
||||
request.setPrivateKey("private-key");
|
||||
request.setPassphrase("passphrase");
|
||||
|
||||
when(encryptionService.encrypt("private-key")).thenReturn("enc-key");
|
||||
when(encryptionService.encrypt("passphrase")).thenReturn("enc-passphrase");
|
||||
|
||||
ConnectionDto result = connectionService.create(request, 1L);
|
||||
|
||||
assertNotNull(result);
|
||||
ArgumentCaptor<Connection> captor = ArgumentCaptor.forClass(Connection.class);
|
||||
verify(connectionRepository).save(captor.capture());
|
||||
Connection saved = captor.getValue();
|
||||
|
||||
assertEquals(Connection.AuthType.PRIVATE_KEY, saved.getAuthType());
|
||||
assertEquals("enc-key", saved.getEncryptedPrivateKey());
|
||||
assertEquals("enc-passphrase", saved.getPassphrase());
|
||||
assertNull(saved.getEncryptedPassword());
|
||||
}
|
||||
|
||||
@Test
|
||||
void updateSwitchToPrivateKeyClearsPasswordCredential() {
|
||||
Connection existing = new Connection();
|
||||
existing.setId(10L);
|
||||
existing.setUserId(1L);
|
||||
existing.setAuthType(Connection.AuthType.PASSWORD);
|
||||
existing.setEncryptedPassword("old-password");
|
||||
|
||||
ConnectionCreateRequest request = new ConnectionCreateRequest();
|
||||
request.setAuthType(Connection.AuthType.PRIVATE_KEY);
|
||||
request.setPrivateKey("new-key");
|
||||
request.setPassphrase("new-passphrase");
|
||||
|
||||
when(connectionRepository.findById(10L)).thenReturn(Optional.of(existing));
|
||||
when(encryptionService.encrypt("new-key")).thenReturn("enc-new-key");
|
||||
when(encryptionService.encrypt("new-passphrase")).thenReturn("enc-new-passphrase");
|
||||
|
||||
ConnectionDto result = connectionService.update(10L, request, 1L);
|
||||
|
||||
assertNotNull(result);
|
||||
ArgumentCaptor<Connection> captor = ArgumentCaptor.forClass(Connection.class);
|
||||
verify(connectionRepository).save(captor.capture());
|
||||
Connection saved = captor.getValue();
|
||||
|
||||
assertEquals(Connection.AuthType.PRIVATE_KEY, saved.getAuthType());
|
||||
assertEquals("enc-new-key", saved.getEncryptedPrivateKey());
|
||||
assertEquals("enc-new-passphrase", saved.getPassphrase());
|
||||
assertNull(saved.getEncryptedPassword());
|
||||
}
|
||||
|
||||
@Test
|
||||
void updateSwitchToPasswordClearsPrivateKeyCredentials() {
|
||||
Connection existing = new Connection();
|
||||
existing.setId(20L);
|
||||
existing.setUserId(1L);
|
||||
existing.setAuthType(Connection.AuthType.PRIVATE_KEY);
|
||||
existing.setEncryptedPrivateKey("old-key");
|
||||
existing.setPassphrase("old-passphrase");
|
||||
|
||||
ConnectionCreateRequest request = new ConnectionCreateRequest();
|
||||
request.setAuthType(Connection.AuthType.PASSWORD);
|
||||
request.setPassword("new-password");
|
||||
|
||||
when(connectionRepository.findById(20L)).thenReturn(Optional.of(existing));
|
||||
when(encryptionService.encrypt("new-password")).thenReturn("enc-new-password");
|
||||
|
||||
ConnectionDto result = connectionService.update(20L, request, 1L);
|
||||
|
||||
assertNotNull(result);
|
||||
ArgumentCaptor<Connection> captor = ArgumentCaptor.forClass(Connection.class);
|
||||
verify(connectionRepository).save(captor.capture());
|
||||
Connection saved = captor.getValue();
|
||||
|
||||
assertEquals(Connection.AuthType.PASSWORD, saved.getAuthType());
|
||||
assertEquals("enc-new-password", saved.getEncryptedPassword());
|
||||
assertNull(saved.getEncryptedPrivateKey());
|
||||
assertNull(saved.getPassphrase());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user