diff --git a/backend/src/main/java/com/sshmanager/dto/ConnectionCreateRequest.java b/backend/src/main/java/com/sshmanager/dto/ConnectionCreateRequest.java index 65a4229..0878e93 100644 --- a/backend/src/main/java/com/sshmanager/dto/ConnectionCreateRequest.java +++ b/backend/src/main/java/com/sshmanager/dto/ConnectionCreateRequest.java @@ -4,13 +4,20 @@ import com.sshmanager.entity.Connection; import lombok.Data; @Data -public class ConnectionCreateRequest { - private String name; - private String host; - private Integer port = 22; - private String username; - private Connection.AuthType authType = Connection.AuthType.PASSWORD; - private String password; - private String privateKey; - private String passphrase; -} +public class ConnectionCreateRequest { + public enum SetupMode { + NONE, + PASSWORD_BOOTSTRAP + } + + private String name; + private String host; + private Integer port = 22; + private String username; + private Connection.AuthType authType = Connection.AuthType.PASSWORD; + private String password; + private String privateKey; + private String passphrase; + private SetupMode setupMode = SetupMode.NONE; + private String bootstrapPassword; +} diff --git a/backend/src/main/java/com/sshmanager/service/ConnectionService.java b/backend/src/main/java/com/sshmanager/service/ConnectionService.java index a7b8df4..e3bfd6b 100644 --- a/backend/src/main/java/com/sshmanager/service/ConnectionService.java +++ b/backend/src/main/java/com/sshmanager/service/ConnectionService.java @@ -1,11 +1,14 @@ package com.sshmanager.service; -import com.sshmanager.dto.ConnectionCreateRequest; -import com.sshmanager.dto.ConnectionDto; -import com.sshmanager.entity.Connection; -import com.sshmanager.repository.ConnectionRepository; -import org.springframework.stereotype.Service; -import org.springframework.transaction.annotation.Transactional; +import com.sshmanager.dto.ConnectionCreateRequest; +import com.sshmanager.dto.ConnectionDto; +import com.sshmanager.entity.Connection; +import com.sshmanager.exception.AccessDeniedException; +import com.sshmanager.exception.InvalidOperationException; +import com.sshmanager.exception.NotFoundException; +import com.sshmanager.repository.ConnectionRepository; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import java.time.Instant; import java.util.List; @@ -14,17 +17,20 @@ import java.util.stream.Collectors; @Service public class ConnectionService { - private final ConnectionRepository connectionRepository; - private final EncryptionService encryptionService; - private final SshService sshService; - - public ConnectionService(ConnectionRepository connectionRepository, - EncryptionService encryptionService, - SshService sshService) { - this.connectionRepository = connectionRepository; - this.encryptionService = encryptionService; - this.sshService = sshService; - } + private final ConnectionRepository connectionRepository; + private final EncryptionService encryptionService; + private final SshService sshService; + private final SshBootstrapService sshBootstrapService; + + public ConnectionService(ConnectionRepository connectionRepository, + EncryptionService encryptionService, + SshService sshService, + SshBootstrapService sshBootstrapService) { + this.connectionRepository = connectionRepository; + this.encryptionService = encryptionService; + this.sshService = sshService; + this.sshBootstrapService = sshBootstrapService; + } public List listByUserId(Long userId) { return connectionRepository.findByUserIdOrderByUpdatedAtDesc(userId).stream() @@ -32,73 +38,69 @@ public class ConnectionService { .collect(Collectors.toList()); } - public ConnectionDto getById(Long id, Long userId) { - Connection conn = connectionRepository.findById(id).orElseThrow( - () -> new RuntimeException("Connection not found: " + id)); - if (!conn.getUserId().equals(userId)) { - throw new RuntimeException("Access denied"); - } - return ConnectionDto.fromEntity(conn); - } - - @Transactional - public ConnectionDto create(ConnectionCreateRequest request, Long userId) { - Connection conn = new Connection(); - conn.setUserId(userId); - conn.setName(request.getName()); - conn.setHost(request.getHost()); - conn.setPort(request.getPort() != null ? request.getPort() : 22); - conn.setUsername(request.getUsername()); - conn.setAuthType(request.getAuthType() != null ? request.getAuthType() : Connection.AuthType.PASSWORD); + public ConnectionDto getById(Long id, Long userId) { + Connection conn = connectionRepository.findById(id).orElseThrow( + () -> new NotFoundException("Connection not found: " + id)); + if (!conn.getUserId().equals(userId)) { + throw new AccessDeniedException("Access denied"); + } + return ConnectionDto.fromEntity(conn); + } - if (conn.getAuthType() == Connection.AuthType.PASSWORD) { - conn.setEncryptedPassword(encryptionService.encrypt(request.getPassword())); - conn.setEncryptedPrivateKey(null); + @Transactional + public ConnectionDto create(ConnectionCreateRequest request, Long userId) { + validateCreateRequest(request); + + Connection conn = new Connection(); + conn.setUserId(userId); + conn.setName(trimToNull(request.getName())); + conn.setHost(trimToNull(request.getHost())); + conn.setPort(request.getPort() != null ? request.getPort() : 22); + conn.setUsername(trimToNull(request.getUsername())); + + if (getSetupMode(request) == ConnectionCreateRequest.SetupMode.PASSWORD_BOOTSTRAP) { + SshBootstrapService.BootstrapResult bootstrapResult = sshBootstrapService.bootstrapWithPassword(request, userId); + conn.setAuthType(Connection.AuthType.PRIVATE_KEY); + conn.setEncryptedPassword(null); + conn.setEncryptedPrivateKey(encryptionService.encrypt(bootstrapResult.getPrivateKey())); conn.setPassphrase(null); } else { - conn.setEncryptedPassword(null); - conn.setEncryptedPrivateKey(encryptionService.encrypt(request.getPrivateKey())); - conn.setPassphrase(encryptionService.encrypt(request.getPassphrase())); + conn.setAuthType(resolveAuthType(request)); + applyCredentialUpdate(conn, request); } - - conn = connectionRepository.save(conn); - return ConnectionDto.fromEntity(conn); - } - - @Transactional - public ConnectionDto update(Long id, ConnectionCreateRequest request, Long userId) { - Connection conn = connectionRepository.findById(id).orElseThrow( - () -> new RuntimeException("Connection not found: " + id)); - if (!conn.getUserId().equals(userId)) { - throw new RuntimeException("Access denied"); - } - - if (request.getName() != null) conn.setName(request.getName()); - if (request.getHost() != null) conn.setHost(request.getHost()); - if (request.getPort() != null) conn.setPort(request.getPort()); - if (request.getUsername() != null) conn.setUsername(request.getUsername()); + + conn = connectionRepository.save(conn); + return ConnectionDto.fromEntity(conn); + } + + @Transactional + public ConnectionDto update(Long id, ConnectionCreateRequest request, Long userId) { + Connection conn = connectionRepository.findById(id).orElseThrow( + () -> new NotFoundException("Connection not found: " + id)); + if (!conn.getUserId().equals(userId)) { + throw new AccessDeniedException("Access denied"); + } + if (getSetupMode(request) == ConnectionCreateRequest.SetupMode.PASSWORD_BOOTSTRAP) { + throw new InvalidOperationException("编辑连接时不支持一键免密配置"); + } + + if (request.getName() != null) conn.setName(trimToNull(request.getName())); + if (request.getHost() != null) conn.setHost(trimToNull(request.getHost())); + if (request.getPort() != null) { + validatePort(request.getPort()); + conn.setPort(request.getPort()); + } + if (request.getUsername() != null) conn.setUsername(trimToNull(request.getUsername())); if (request.getAuthType() != null) conn.setAuthType(request.getAuthType()); - if (conn.getAuthType() == Connection.AuthType.PASSWORD) { - if (request.getPassword() != null) { - conn.setEncryptedPassword(encryptionService.encrypt(request.getPassword())); - } - conn.setEncryptedPrivateKey(null); - conn.setPassphrase(null); - } else { - if (request.getPrivateKey() != null) { - conn.setEncryptedPrivateKey(encryptionService.encrypt(request.getPrivateKey())); - } - if (request.getPassphrase() != null) { - conn.setPassphrase(encryptionService.encrypt(request.getPassphrase())); - } - conn.setEncryptedPassword(null); - } - - conn.setUpdatedAt(Instant.now()); - conn = connectionRepository.save(conn); - return ConnectionDto.fromEntity(conn); - } + validatePersistedFields(conn); + applyCredentialUpdate(conn, request); + validateStoredCredentials(conn); + + conn.setUpdatedAt(Instant.now()); + conn = connectionRepository.save(conn); + return ConnectionDto.fromEntity(conn); + } @Transactional public void delete(Long id, Long userId) { @@ -107,19 +109,19 @@ public class ConnectionService { return; } if (!conn.getUserId().equals(userId)) { - throw new RuntimeException("Access denied"); + throw new AccessDeniedException("Access denied"); } connectionRepository.delete(conn); } - - public Connection getConnectionForSsh(Long id, Long userId) { - Connection conn = connectionRepository.findById(id).orElseThrow( - () -> new RuntimeException("Connection not found: " + id)); - if (!conn.getUserId().equals(userId)) { - throw new RuntimeException("Access denied"); - } - return conn; - } + + public Connection getConnectionForSsh(Long id, Long userId) { + Connection conn = connectionRepository.findById(id).orElseThrow( + () -> new NotFoundException("Connection not found: " + id)); + if (!conn.getUserId().equals(userId)) { + throw new AccessDeniedException("Access denied"); + } + return conn; + } public String getDecryptedPassword(Connection conn) { return conn.getEncryptedPassword() != null ? @@ -144,9 +146,115 @@ public class ConnectionService { } catch (Exception e) { throw new RuntimeException("Connection test failed: " + e.getMessage(), e); } finally { - if (session != null) { - session.disconnect(); - } - } - } -} + if (session != null) { + session.disconnect(); + } + } + } + + private void validateCreateRequest(ConnectionCreateRequest request) { + if (request == null) { + throw new InvalidOperationException("连接信息不能为空"); + } + validatePersistedFields(request); + if (getSetupMode(request) == ConnectionCreateRequest.SetupMode.PASSWORD_BOOTSTRAP) { + requireText(request.getBootstrapPassword(), "启用一键免密配置时必须填写初始登录密码"); + return; + } + + Connection.AuthType authType = resolveAuthType(request); + if (authType == Connection.AuthType.PASSWORD) { + if (!hasText(request.getPassword())) { + throw new InvalidOperationException("请填写密码"); + } + return; + } + + if (!hasText(request.getPrivateKey())) { + throw new InvalidOperationException("请填写私钥"); + } + } + + private void validatePersistedFields(ConnectionCreateRequest request) { + requireText(request.getName(), "请填写名称"); + requireText(request.getHost(), "请填写主机"); + requireText(request.getUsername(), "请填写用户名"); + validatePort(request.getPort() != null ? request.getPort() : 22); + } + + private void validatePersistedFields(Connection conn) { + requireText(conn.getName(), "请填写名称"); + requireText(conn.getHost(), "请填写主机"); + requireText(conn.getUsername(), "请填写用户名"); + validatePort(conn.getPort() != null ? conn.getPort() : 22); + } + + private void applyCredentialUpdate(Connection conn, ConnectionCreateRequest request) { + if (conn.getAuthType() == Connection.AuthType.PASSWORD) { + String password = trimToNull(request.getPassword()); + if (password != null) { + conn.setEncryptedPassword(encryptionService.encrypt(password)); + } + conn.setEncryptedPrivateKey(null); + conn.setPassphrase(null); + return; + } + + String privateKey = trimToNull(request.getPrivateKey()); + if (privateKey != null) { + conn.setEncryptedPrivateKey(encryptionService.encrypt(privateKey)); + } + if (request.getPassphrase() != null) { + conn.setPassphrase(encryptionService.encrypt(trimToNull(request.getPassphrase()))); + } + conn.setEncryptedPassword(null); + } + + private void validateStoredCredentials(Connection conn) { + if (conn.getAuthType() == Connection.AuthType.PASSWORD) { + if (!hasText(conn.getEncryptedPassword())) { + throw new InvalidOperationException("请填写密码"); + } + return; + } + + if (!hasText(conn.getEncryptedPrivateKey())) { + throw new InvalidOperationException("请填写私钥"); + } + } + + private Connection.AuthType resolveAuthType(ConnectionCreateRequest request) { + return request.getAuthType() != null ? request.getAuthType() : Connection.AuthType.PASSWORD; + } + + private ConnectionCreateRequest.SetupMode getSetupMode(ConnectionCreateRequest request) { + if (request == null || request.getSetupMode() == null) { + return ConnectionCreateRequest.SetupMode.NONE; + } + return request.getSetupMode(); + } + + private void requireText(String value, String message) { + if (!hasText(value)) { + throw new InvalidOperationException(message); + } + } + + private void validatePort(Integer port) { + if (port == null || port < 1 || port > 65535) { + throw new InvalidOperationException("端口号必须在1-65535之间"); + } + } + + private boolean hasText(String value) { + return value != null && !value.trim().isEmpty(); + } + + private String trimToNull(String value) { + if (value == null) { + return null; + } + String trimmed = value.trim(); + return trimmed.isEmpty() ? null : trimmed; + } +} diff --git a/backend/src/main/java/com/sshmanager/service/SshBootstrapService.java b/backend/src/main/java/com/sshmanager/service/SshBootstrapService.java new file mode 100644 index 0000000..6ce01aa --- /dev/null +++ b/backend/src/main/java/com/sshmanager/service/SshBootstrapService.java @@ -0,0 +1,221 @@ +package com.sshmanager.service; + +import com.jcraft.jsch.JSch; +import com.jcraft.jsch.KeyPair; +import com.sshmanager.dto.ConnectionCreateRequest; +import com.sshmanager.entity.Connection; +import com.sshmanager.exception.InvalidOperationException; +import org.springframework.stereotype.Service; + +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.time.Instant; + +@Service +public class SshBootstrapService { + + private final SshService sshService; + + public SshBootstrapService(SshService sshService) { + this.sshService = sshService; + } + + public BootstrapResult bootstrapWithPassword(ConnectionCreateRequest request, Long userId) { + String bootstrapPassword = trimToNull(request.getBootstrapPassword()); + if (bootstrapPassword == null) { + throw new InvalidOperationException("启用一键免密配置时必须填写初始登录密码"); + } + + GeneratedKeyPair keyPair = generateKeyPair(buildKeyComment(userId, request.getName())); + Connection passwordConnection = buildConnection(request, Connection.AuthType.PASSWORD); + Connection privateKeyConnection = buildConnection(request, Connection.AuthType.PRIVATE_KEY); + + authorizePublicKey(passwordConnection, bootstrapPassword, keyPair.getPublicKey()); + verifyPrivateKeyLogin(privateKeyConnection, keyPair.getPrivateKey()); + + return new BootstrapResult(keyPair.getPrivateKey()); + } + + private GeneratedKeyPair generateKeyPair(String comment) { + KeyPair keyPair = null; + try { + JSch jsch = new JSch(); + keyPair = KeyPair.genKeyPair(jsch, KeyPair.RSA, 2048); + + ByteArrayOutputStream privateKeyOutput = new ByteArrayOutputStream(); + ByteArrayOutputStream publicKeyOutput = new ByteArrayOutputStream(); + keyPair.writePrivateKey(privateKeyOutput); + keyPair.writePublicKey(publicKeyOutput, comment); + + return new GeneratedKeyPair( + privateKeyOutput.toString(StandardCharsets.UTF_8.name()), + publicKeyOutput.toString(StandardCharsets.UTF_8.name()).trim() + ); + } catch (Exception e) { + throw new InvalidOperationException("免密初始化失败:无法生成 SSH 密钥"); + } finally { + if (keyPair != null) { + keyPair.dispose(); + } + } + } + + private void authorizePublicKey(Connection connection, String bootstrapPassword, String publicKey) { + String command = buildAuthorizeCommand(publicKey); + try { + SshService.CommandResult result = sshService.executeCommandWithResult( + connection, + bootstrapPassword, + null, + null, + command + ); + if (result.getExitStatus() != 0) { + throw new InvalidOperationException(buildRemoteFailureMessage( + "免密初始化失败:无法写入远端 authorized_keys", + result + )); + } + } catch (InvalidOperationException e) { + throw e; + } catch (Exception e) { + throw new InvalidOperationException("免密初始化失败:密码登录或公钥下发失败" + formatCauseMessage(e)); + } + } + + private void verifyPrivateKeyLogin(Connection connection, String privateKey) { + try { + SshService.CommandResult result = sshService.executeCommandWithResult( + connection, + null, + privateKey, + null, + "printf 'ssh-manager bootstrap ok'" + ); + if (result.getExitStatus() != 0) { + throw new InvalidOperationException(buildRemoteFailureMessage( + "免密初始化失败:公钥已下发,但私钥验证失败", + result + )); + } + } catch (InvalidOperationException e) { + throw e; + } catch (Exception e) { + throw new InvalidOperationException("免密初始化失败:公钥已下发,但私钥验证失败" + formatCauseMessage(e)); + } + } + + private String buildAuthorizeCommand(String publicKey) { + String escapedPublicKey = shellQuote(publicKey); + String innerCommand = + "umask 077 && " + + "mkdir -p ~/.ssh && chmod 700 ~/.ssh && " + + "touch ~/.ssh/authorized_keys && chmod 600 ~/.ssh/authorized_keys && " + + "{ grep -Fqx " + escapedPublicKey + " ~/.ssh/authorized_keys || " + + "printf '%s\\n' " + escapedPublicKey + " >> ~/.ssh/authorized_keys; }"; + return "sh -lc " + shellQuote(innerCommand); + } + + private String buildRemoteFailureMessage(String prefix, SshService.CommandResult result) { + String stderr = trimToNull(result.getStderr()); + String stdout = trimToNull(result.getStdout()); + String detail = stderr != null ? stderr : stdout; + if (detail == null) { + return prefix; + } + return prefix + ":" + detail; + } + + private String buildKeyComment(Long userId, String connectionName) { + String sanitizedName = sanitizeForComment(connectionName); + long timestamp = Instant.now().getEpochSecond(); + if (sanitizedName == null) { + return "ssh-manager-" + userId + "-" + timestamp; + } + return "ssh-manager-" + userId + "-" + timestamp + "-" + sanitizedName; + } + + private String sanitizeForComment(String value) { + String trimmed = trimToNull(value); + if (trimmed == null) { + return null; + } + + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < trimmed.length(); i++) { + char ch = trimmed.charAt(i); + if ((ch >= 'a' && ch <= 'z') + || (ch >= 'A' && ch <= 'Z') + || (ch >= '0' && ch <= '9') + || ch == '.' + || ch == '_' + || ch == '-') { + builder.append(ch); + } else if (builder.length() == 0 || builder.charAt(builder.length() - 1) != '-') { + builder.append('-'); + } + } + + String sanitized = builder.toString().replaceAll("^-+|-+$", ""); + return sanitized.isEmpty() ? null : sanitized; + } + + private Connection buildConnection(ConnectionCreateRequest request, Connection.AuthType authType) { + Connection connection = new Connection(); + connection.setHost(trimToNull(request.getHost())); + connection.setPort(request.getPort() != null ? request.getPort() : 22); + connection.setUsername(trimToNull(request.getUsername())); + connection.setAuthType(authType); + return connection; + } + + private String shellQuote(String value) { + return "'" + value.replace("'", "'\"'\"'") + "'"; + } + + private String formatCauseMessage(Exception e) { + String message = trimToNull(e.getMessage()); + if (message == null) { + return ""; + } + return ":" + message; + } + + private String trimToNull(String value) { + if (value == null) { + return null; + } + String trimmed = value.trim(); + return trimmed.isEmpty() ? null : trimmed; + } + + public static class BootstrapResult { + private final String privateKey; + + public BootstrapResult(String privateKey) { + this.privateKey = privateKey; + } + + public String getPrivateKey() { + return privateKey; + } + } + + private static class GeneratedKeyPair { + private final String privateKey; + private final String publicKey; + + private GeneratedKeyPair(String privateKey, String publicKey) { + this.privateKey = privateKey; + this.publicKey = publicKey; + } + + public String getPrivateKey() { + return privateKey; + } + + public String getPublicKey() { + return publicKey; + } + } +} diff --git a/backend/src/main/java/com/sshmanager/service/SshService.java b/backend/src/main/java/com/sshmanager/service/SshService.java index f4519a7..97e59ab 100644 --- a/backend/src/main/java/com/sshmanager/service/SshService.java +++ b/backend/src/main/java/com/sshmanager/service/SshService.java @@ -4,14 +4,15 @@ import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.ChannelShell; import com.jcraft.jsch.JSch; import com.jcraft.jsch.Session; -import com.sshmanager.entity.Connection; -import org.springframework.stereotype.Service; - -import java.io.BufferedReader; -import java.io.InputStreamReader; -import java.nio.charset.StandardCharsets; -import java.io.InputStream; -import java.io.OutputStream; +import com.sshmanager.entity.Connection; +import org.springframework.stereotype.Service; + +import java.io.ByteArrayOutputStream; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.io.InputStream; +import java.io.OutputStream; import java.io.PipedInputStream; import java.io.PipedOutputStream; @@ -69,15 +70,24 @@ public class SshService { }).start(); return new SshSession(session, channel, channelOut, pipeToChannel); - } - - // 执行单次命令并返回输出 - public String executeCommand(Connection conn, String password, String privateKey, String passphrase, String command) throws Exception { - JSch jsch = new JSch(); - - if (conn.getAuthType() == Connection.AuthType.PRIVATE_KEY && privateKey != null && !privateKey.isEmpty()) { - byte[] keyBytes = privateKey.getBytes(StandardCharsets.UTF_8); - byte[] passphraseBytes = (passphrase != null && !passphrase.isEmpty()) + } + + // 执行单次命令并返回输出 + public String executeCommand(Connection conn, String password, String privateKey, String passphrase, String command) throws Exception { + CommandResult result = executeCommandWithResult(conn, password, privateKey, passphrase, command); + return result.getStdout(); + } + + public CommandResult executeCommandWithResult(Connection conn, + String password, + String privateKey, + String passphrase, + String command) throws Exception { + JSch jsch = new JSch(); + + if (conn.getAuthType() == Connection.AuthType.PRIVATE_KEY && privateKey != null && !privateKey.isEmpty()) { + byte[] keyBytes = privateKey.getBytes(StandardCharsets.UTF_8); + byte[] passphraseBytes = (passphrase != null && !passphrase.isEmpty()) ? passphrase.getBytes(StandardCharsets.UTF_8) : null; jsch.addIdentity("key", keyBytes, null, passphraseBytes); } @@ -92,27 +102,40 @@ public class SshService { } session.connect(8000); - - ChannelExec channel = (ChannelExec) session.openChannel("exec"); - channel.setCommand(command); - channel.setErrStream(System.err); - - InputStream in = channel.getInputStream(); - channel.connect(3000); - - BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8)); - StringBuilder result = new StringBuilder(); - String line; - while ((line = reader.readLine()) != null) { - result.append(line).append("\n"); - } - - channel.disconnect(); - session.disconnect(); - - return result.toString().trim(); + + ChannelExec channel = (ChannelExec) session.openChannel("exec"); + channel.setCommand(command); + InputStream in = channel.getInputStream(); + ByteArrayOutputStream stderr = new ByteArrayOutputStream(); + channel.setErrStream(stderr, true); + channel.connect(3000); + + String stdout = readStream(in); + while (!channel.isClosed()) { + Thread.sleep(50L); + } + String stderrText = stderr.toString(StandardCharsets.UTF_8.name()).trim(); + int exitStatus = channel.getExitStatus(); + + channel.disconnect(); + session.disconnect(); + + return new CommandResult(stdout, stderrText, exitStatus); } - + + private String readStream(InputStream inputStream) throws Exception { + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)); + StringBuilder result = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + if (result.length() > 0) { + result.append('\n'); + } + result.append(line); + } + return result.toString().trim(); + } + public static class SshSession { private final Session session; private final ChannelShell channel; @@ -156,8 +179,32 @@ public class SshService { } } - public boolean isConnected() { - return channel != null && channel.isConnected(); - } + public boolean isConnected() { + return channel != null && channel.isConnected(); + } + } + + public static class CommandResult { + private final String stdout; + private final String stderr; + private final int exitStatus; + + public CommandResult(String stdout, String stderr, int exitStatus) { + this.stdout = stdout; + this.stderr = stderr; + this.exitStatus = exitStatus; + } + + public String getStdout() { + return stdout; + } + + public String getStderr() { + return stderr; + } + + public int getExitStatus() { + return exitStatus; + } } } diff --git a/backend/src/test/java/com/sshmanager/service/ConnectionServiceTest.java b/backend/src/test/java/com/sshmanager/service/ConnectionServiceTest.java index d7c7273..cf92037 100644 --- a/backend/src/test/java/com/sshmanager/service/ConnectionServiceTest.java +++ b/backend/src/test/java/com/sshmanager/service/ConnectionServiceTest.java @@ -3,6 +3,7 @@ package com.sshmanager.service; import com.sshmanager.dto.ConnectionCreateRequest; import com.sshmanager.dto.ConnectionDto; import com.sshmanager.entity.Connection; +import com.sshmanager.exception.InvalidOperationException; import com.sshmanager.repository.ConnectionRepository; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -17,9 +18,13 @@ import java.util.Optional; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.lenient; @ExtendWith(MockitoExtension.class) class ConnectionServiceTest { @@ -30,12 +35,18 @@ class ConnectionServiceTest { @Mock private EncryptionService encryptionService; + @Mock + private SshService sshService; + + @Mock + private SshBootstrapService sshBootstrapService; + @InjectMocks private ConnectionService connectionService; @BeforeEach void setUp() { - when(connectionRepository.save(any(Connection.class))).thenAnswer(invocation -> invocation.getArgument(0)); + lenient().when(connectionRepository.save(any(Connection.class))).thenAnswer(invocation -> invocation.getArgument(0)); } @Test @@ -91,11 +102,82 @@ class ConnectionServiceTest { assertNull(saved.getEncryptedPassword()); } + @Test + void createPasswordBootstrapConnectionSavesGeneratedPrivateKey() { + ConnectionCreateRequest request = new ConnectionCreateRequest(); + request.setName("prod"); + request.setHost("127.0.0.1"); + request.setPort(22); + request.setUsername("root"); + request.setSetupMode(ConnectionCreateRequest.SetupMode.PASSWORD_BOOTSTRAP); + request.setBootstrapPassword("bootstrap-secret"); + + when(sshBootstrapService.bootstrapWithPassword(request, 1L)) + .thenReturn(new SshBootstrapService.BootstrapResult("generated-private-key")); + when(encryptionService.encrypt("generated-private-key")).thenReturn("enc-generated-private-key"); + + ConnectionDto result = connectionService.create(request, 1L); + + assertNotNull(result); + ArgumentCaptor captor = ArgumentCaptor.forClass(Connection.class); + verify(connectionRepository).save(captor.capture()); + Connection saved = captor.getValue(); + + assertEquals(Connection.AuthType.PRIVATE_KEY, saved.getAuthType()); + assertNull(saved.getEncryptedPassword()); + assertEquals("enc-generated-private-key", saved.getEncryptedPrivateKey()); + assertNull(saved.getPassphrase()); + verify(sshBootstrapService).bootstrapWithPassword(request, 1L); + verify(encryptionService, never()).encrypt("bootstrap-secret"); + } + + @Test + void createPasswordBootstrapConnectionRequiresBootstrapPassword() { + ConnectionCreateRequest request = new ConnectionCreateRequest(); + request.setName("prod"); + request.setHost("127.0.0.1"); + request.setUsername("root"); + request.setSetupMode(ConnectionCreateRequest.SetupMode.PASSWORD_BOOTSTRAP); + + InvalidOperationException exception = assertThrows( + InvalidOperationException.class, + () -> connectionService.create(request, 1L) + ); + + assertEquals("启用一键免密配置时必须填写初始登录密码", exception.getMessage()); + verifyNoInteractions(connectionRepository, sshBootstrapService, encryptionService); + } + + @Test + void createPasswordBootstrapConnectionDoesNotSaveWhenBootstrapFails() { + ConnectionCreateRequest request = new ConnectionCreateRequest(); + request.setName("prod"); + request.setHost("127.0.0.1"); + request.setUsername("root"); + request.setSetupMode(ConnectionCreateRequest.SetupMode.PASSWORD_BOOTSTRAP); + request.setBootstrapPassword("bootstrap-secret"); + + when(sshBootstrapService.bootstrapWithPassword(request, 1L)) + .thenThrow(new InvalidOperationException("免密初始化失败")); + + InvalidOperationException exception = assertThrows( + InvalidOperationException.class, + () -> connectionService.create(request, 1L) + ); + + assertEquals("免密初始化失败", exception.getMessage()); + verify(connectionRepository, never()).save(any(Connection.class)); + } + @Test void updateSwitchToPrivateKeyClearsPasswordCredential() { Connection existing = new Connection(); existing.setId(10L); existing.setUserId(1L); + existing.setName("prod"); + existing.setHost("127.0.0.1"); + existing.setPort(22); + existing.setUsername("root"); existing.setAuthType(Connection.AuthType.PASSWORD); existing.setEncryptedPassword("old-password"); @@ -126,6 +208,10 @@ class ConnectionServiceTest { Connection existing = new Connection(); existing.setId(20L); existing.setUserId(1L); + existing.setName("prod"); + existing.setHost("127.0.0.1"); + existing.setPort(22); + existing.setUsername("root"); existing.setAuthType(Connection.AuthType.PRIVATE_KEY); existing.setEncryptedPrivateKey("old-key"); existing.setPassphrase("old-passphrase"); @@ -149,4 +235,30 @@ class ConnectionServiceTest { assertNull(saved.getEncryptedPrivateKey()); assertNull(saved.getPassphrase()); } + + @Test + void updateRejectsPasswordBootstrapMode() { + Connection existing = new Connection(); + existing.setId(20L); + existing.setUserId(1L); + existing.setName("prod"); + existing.setHost("127.0.0.1"); + existing.setPort(22); + existing.setUsername("root"); + existing.setAuthType(Connection.AuthType.PASSWORD); + existing.setEncryptedPassword("old-password"); + + ConnectionCreateRequest request = new ConnectionCreateRequest(); + request.setSetupMode(ConnectionCreateRequest.SetupMode.PASSWORD_BOOTSTRAP); + + when(connectionRepository.findById(20L)).thenReturn(Optional.of(existing)); + + InvalidOperationException exception = assertThrows( + InvalidOperationException.class, + () -> connectionService.update(20L, request, 1L) + ); + + assertEquals("编辑连接时不支持一键免密配置", exception.getMessage()); + verify(connectionRepository, never()).save(any(Connection.class)); + } } diff --git a/backend/src/test/java/com/sshmanager/service/SshBootstrapServiceTest.java b/backend/src/test/java/com/sshmanager/service/SshBootstrapServiceTest.java new file mode 100644 index 0000000..ce193bf --- /dev/null +++ b/backend/src/test/java/com/sshmanager/service/SshBootstrapServiceTest.java @@ -0,0 +1,104 @@ +package com.sshmanager.service; + +import com.sshmanager.dto.ConnectionCreateRequest; +import com.sshmanager.entity.Connection; +import com.sshmanager.exception.InvalidOperationException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class SshBootstrapServiceTest { + + @Mock + private SshService sshService; + + @InjectMocks + private SshBootstrapService sshBootstrapService; + + @Test + void bootstrapWithPasswordGeneratesKeyAndVerifiesPrivateKeyLogin() throws Exception { + ConnectionCreateRequest request = new ConnectionCreateRequest(); + request.setName("prod"); + request.setHost("127.0.0.1"); + request.setPort(22); + request.setUsername("root"); + request.setBootstrapPassword("bootstrap-secret"); + + when(sshService.executeCommandWithResult(any(Connection.class), anyString(), isNull(), isNull(), anyString())) + .thenReturn(new SshService.CommandResult("", "", 0)); + when(sshService.executeCommandWithResult(any(Connection.class), isNull(), anyString(), isNull(), anyString())) + .thenReturn(new SshService.CommandResult("ssh-manager bootstrap ok", "", 0)); + + SshBootstrapService.BootstrapResult result = sshBootstrapService.bootstrapWithPassword(request, 1L); + + assertNotNull(result); + assertTrue(result.getPrivateKey().contains("PRIVATE KEY")); + + ArgumentCaptor connectionCaptor = ArgumentCaptor.forClass(Connection.class); + ArgumentCaptor passwordCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor privateKeyCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor commandCaptor = ArgumentCaptor.forClass(String.class); + verify(sshService, times(2)).executeCommandWithResult( + connectionCaptor.capture(), + passwordCaptor.capture(), + privateKeyCaptor.capture(), + isNull(), + commandCaptor.capture() + ); + + List connections = connectionCaptor.getAllValues(); + List passwords = passwordCaptor.getAllValues(); + List privateKeys = privateKeyCaptor.getAllValues(); + List commands = commandCaptor.getAllValues(); + + assertEquals(Connection.AuthType.PASSWORD, connections.get(0).getAuthType()); + assertEquals("bootstrap-secret", passwords.get(0)); + assertNull(privateKeys.get(0)); + assertTrue(commands.get(0).contains("authorized_keys")); + + assertEquals(Connection.AuthType.PRIVATE_KEY, connections.get(1).getAuthType()); + assertNull(passwords.get(1)); + assertNotNull(privateKeys.get(1)); + assertTrue(privateKeys.get(1).contains("PRIVATE KEY")); + assertTrue(commands.get(1).contains("ssh-manager bootstrap ok")); + } + + @Test + void bootstrapWithPasswordFailsWhenRemoteAuthorizationCommandFails() throws Exception { + ConnectionCreateRequest request = new ConnectionCreateRequest(); + request.setName("prod"); + request.setHost("127.0.0.1"); + request.setPort(22); + request.setUsername("root"); + request.setBootstrapPassword("bootstrap-secret"); + + when(sshService.executeCommandWithResult(any(Connection.class), anyString(), isNull(), isNull(), anyString())) + .thenReturn(new SshService.CommandResult("", "permission denied", 1)); + + InvalidOperationException exception = assertThrows( + InvalidOperationException.class, + () -> sshBootstrapService.bootstrapWithPassword(request, 1L) + ); + + assertTrue(exception.getMessage().contains("authorized_keys")); + assertTrue(exception.getMessage().contains("permission denied")); + } +} diff --git a/frontend/src/api/connections.ts b/frontend/src/api/connections.ts index 35a2820..691955c 100644 --- a/frontend/src/api/connections.ts +++ b/frontend/src/api/connections.ts @@ -1,6 +1,7 @@ import client from './client' -export type AuthType = 'PASSWORD' | 'PRIVATE_KEY' +export type AuthType = 'PASSWORD' | 'PRIVATE_KEY' +export type ConnectionSetupMode = 'NONE' | 'PASSWORD_BOOTSTRAP' export interface Connection { id: number @@ -18,11 +19,13 @@ export interface ConnectionCreateRequest { host: string port?: number username: string - authType?: AuthType - password?: string - privateKey?: string - passphrase?: string -} + authType?: AuthType + password?: string + privateKey?: string + passphrase?: string + setupMode?: ConnectionSetupMode + bootstrapPassword?: string +} export function listConnections() { return client.get('/connections') diff --git a/frontend/src/components/ConnectionForm.vue b/frontend/src/components/ConnectionForm.vue index b88a6a4..dd49a01 100644 --- a/frontend/src/components/ConnectionForm.vue +++ b/frontend/src/components/ConnectionForm.vue @@ -16,15 +16,18 @@ const emit = defineEmits<{ const name = ref('') const host = ref('') const port = ref(22) -const username = ref('') -const authType = ref('PASSWORD') -const password = ref('') -const privateKey = ref('') -const privateKeyFileName = ref('') -const privateKeyInputRef = ref(null) -const passphrase = ref('') - -const isEdit = computed(() => !!props.connection) +const username = ref('') +const authType = ref('PASSWORD') +const password = ref('') +const privateKey = ref('') +const privateKeyFileName = ref('') +const privateKeyInputRef = ref(null) +const passphrase = ref('') +const passwordBootstrapEnabled = ref(false) +const bootstrapPassword = ref('') + +const isEdit = computed(() => !!props.connection) +const isPasswordBootstrapMode = computed(() => !isEdit.value && passwordBootstrapEnabled.value) const hostError = computed(() => { const h = host.value.trim() @@ -54,21 +57,25 @@ watch( host.value = c.host port.value = c.port username.value = c.username - authType.value = c.authType + authType.value = c.authType password.value = '' privateKey.value = '' privateKeyFileName.value = '' passphrase.value = '' + passwordBootstrapEnabled.value = false + bootstrapPassword.value = '' } else { - name.value = '' - host.value = '' - port.value = 22 - username.value = '' - authType.value = 'PASSWORD' + name.value = '' + host.value = '' + port.value = 22 + username.value = '' + authType.value = 'PASSWORD' password.value = '' privateKey.value = '' privateKeyFileName.value = '' passphrase.value = '' + passwordBootstrapEnabled.value = false + bootstrapPassword.value = '' } }, { immediate: true } @@ -147,36 +154,45 @@ async function handleSubmit() { error.value = '请填写主机' return } - if (!username.value.trim()) { - error.value = '请填写用户名' - return - } - if (authType.value === 'PASSWORD' && !isEdit.value && !password.value) { - error.value = '请填写密码' - return - } - if (authType.value === 'PRIVATE_KEY' && !isEdit.value && !privateKey.value.trim()) { - error.value = '请填写私钥' - return - } + if (!username.value.trim()) { + error.value = '请填写用户名' + return + } + if (isPasswordBootstrapMode.value && !bootstrapPassword.value) { + error.value = '请填写初始登录密码' + return + } + if (!isPasswordBootstrapMode.value && authType.value === 'PASSWORD' && !isEdit.value && !password.value) { + error.value = '请填写密码' + return + } + if (!isPasswordBootstrapMode.value && authType.value === 'PRIVATE_KEY' && !isEdit.value && !privateKey.value.trim()) { + error.value = '请填写私钥' + return + } loading.value = true error.value = '' - const data: ConnectionCreateRequest = { - name: name.value.trim(), - host: host.value.trim(), - port: port.value, - username: username.value.trim(), - authType: authType.value, - } - if (authType.value === 'PASSWORD' && password.value) { - data.password = password.value - } - if (authType.value === 'PRIVATE_KEY') { - if (privateKey.value.trim()) data.privateKey = privateKey.value.trim() - if (passphrase.value) data.passphrase = passphrase.value - } - try { + const data: ConnectionCreateRequest = { + name: name.value.trim(), + host: host.value.trim(), + port: port.value, + username: username.value.trim(), + } + if (isPasswordBootstrapMode.value) { + data.setupMode = 'PASSWORD_BOOTSTRAP' + data.bootstrapPassword = bootstrapPassword.value + } else { + data.authType = authType.value + if (authType.value === 'PASSWORD' && password.value) { + data.password = password.value + } + if (authType.value === 'PRIVATE_KEY') { + if (privateKey.value.trim()) data.privateKey = privateKey.value.trim() + if (passphrase.value) data.passphrase = passphrase.value + } + } + try { if (props.onSave) { await props.onSave(data) } else { @@ -252,42 +268,71 @@ async function handleSubmit() {

{{ portError }}

-
- - + + -
-
- -
-
+
+ +
+
+ +
+ -
-
-
- + +
+
+
+
+

创建后将自动切换为私钥认证

+

+ 系统会生成一把新的 SSH 密钥,先用密码登录并写入远端 authorized_keys,验证成功后再保存连接。 +

+
+
+ + +
+
+
+ -
-
+ :placeholder="isEdit ? '••••••••' : ''" + /> +
+
diff --git a/frontend/src/layouts/MobaLayout.vue b/frontend/src/layouts/MobaLayout.vue index 7061dab..1a34849 100644 --- a/frontend/src/layouts/MobaLayout.vue +++ b/frontend/src/layouts/MobaLayout.vue @@ -252,7 +252,7 @@ async function handleSessionSubmit(data: ConnectionCreateRequest) { await connectionsStore.createConnection(data) // Tree node insertion is handled by useConnectionSync -> syncNewConnections. // Avoid manual insertion here to prevent duplicate nodes. - toast.success('连接已创建') + toast.success(data.setupMode === 'PASSWORD_BOOTSTRAP' ? '连接已创建并完成免密配置' : '连接已创建') showFirstRunGuide.value = false } closeSessionModal()