feat: add port forwarding and optimize connection status checks

This commit is contained in:
liumangmang
2026-06-11 14:10:30 +08:00
parent 4a17f0106e
commit e418e6ecc2
30 changed files with 1789 additions and 150 deletions
+29 -29
View File
@@ -53,35 +53,35 @@
<artifactId>jjwt-api</artifactId>
<version>0.11.5</version>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-impl</artifactId>
<version>0.11.5</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
<version>0.11.5</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-impl</artifactId>
<version>0.11.5</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
<version>0.11.5</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<resources>
<resource>
@@ -5,9 +5,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
@@ -35,14 +33,4 @@ public class WebSocketThreadPoolConfig {
);
return executor;
}
@Bean
public ScheduledExecutorService websocketCleanupScheduler() {
ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
scheduler.scheduleAtFixedRate(this::cleanupIdleSessions, 30, 30, TimeUnit.MINUTES);
return scheduler;
}
private void cleanupIdleSessions() {
}
}
@@ -0,0 +1,163 @@
package com.sshmanager.controller;
import com.sshmanager.entity.Connection;
import com.sshmanager.entity.User;
import com.sshmanager.exception.AccessDeniedException;
import com.sshmanager.exception.NotFoundException;
import com.sshmanager.repository.ConnectionRepository;
import com.sshmanager.repository.UserRepository;
import com.sshmanager.service.ConnectionService;
import com.sshmanager.service.PortForwardRegistry;
import com.sshmanager.service.PortForwardRegistry.TunnelEntry;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.Authentication;
import org.springframework.web.bind.annotation.*;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* REST API for managing in-memory SSH port-forwarding tunnels.
*
* <ul>
* <li>GET /api/port-forwards list running tunnels for the current user</li>
* <li>POST /api/port-forwards create a new tunnel</li>
* <li>DELETE /api/port-forwards/{id} stop a tunnel</li>
* </ul>
*
* <p>All tunnels are ephemeral: they live only while the server is running.
*/
@RestController
@RequestMapping("/api/port-forwards")
public class PortForwardController {
private final PortForwardRegistry portForwardRegistry;
private final ConnectionRepository connectionRepository;
private final ConnectionService connectionService;
private final UserRepository userRepository;
public PortForwardController(PortForwardRegistry portForwardRegistry,
ConnectionRepository connectionRepository,
ConnectionService connectionService,
UserRepository userRepository) {
this.portForwardRegistry = portForwardRegistry;
this.connectionRepository = connectionRepository;
this.connectionService = connectionService;
this.userRepository = userRepository;
}
// ── helpers ───────────────────────────────────────────────────────────────
private Long getCurrentUserId(Authentication auth) {
User user = userRepository.findByUsername(auth.getName())
.orElseThrow(() -> new IllegalStateException("User not found"));
return user.getId();
}
private static Map<String, Object> toDto(TunnelEntry e) {
Map<String, Object> dto = new HashMap<>();
dto.put("id", e.getId());
dto.put("connectionId", e.getConnectionId());
dto.put("connectionName", e.getConnectionName());
dto.put("localPort", e.getLocalPort());
dto.put("remoteHost", e.getRemoteHost());
dto.put("remotePort", e.getRemotePort());
dto.put("status", e.getStatus().name().toLowerCase());
dto.put("createdAt", e.getCreatedAt().toString());
return dto;
}
// ── endpoints ─────────────────────────────────────────────────────────────
/** GET /api/port-forwards — list active tunnels for the authenticated user. */
@GetMapping
public ResponseEntity<List<Map<String, Object>>> list(Authentication auth) {
Long userId = getCurrentUserId(auth);
List<Map<String, Object>> dtos = portForwardRegistry.listByUser(userId).stream()
.map(PortForwardController::toDto)
.collect(Collectors.toList());
return ResponseEntity.ok(dtos);
}
/**
* POST /api/port-forwards — create a new port-forwarding tunnel.
*
* <p>Expected request body:
* <pre>
* {
* "connectionId": 42,
* "localPort": 8080,
* "remoteHost": "127.0.0.1",
* "remotePort": 3306
* }
* </pre>
*/
@PostMapping
public ResponseEntity<?> create(@RequestBody Map<String, Object> body, Authentication auth) {
Long userId = getCurrentUserId(auth);
// ── parse & validate request ──────────────────────────────────────────
Long connectionId;
int localPort, remotePort;
String remoteHost;
try {
connectionId = Long.valueOf(body.get("connectionId").toString());
localPort = Integer.parseInt(body.get("localPort").toString());
remotePort = Integer.parseInt(body.get("remotePort").toString());
remoteHost = body.get("remoteHost").toString().trim();
} catch (Exception e) {
Map<String, String> err = new HashMap<>();
err.put("error", "Invalid request body: connectionId, localPort, remoteHost and remotePort are required");
return ResponseEntity.badRequest().body(err);
}
// ── ownership check ───────────────────────────────────────────────────
Connection conn = connectionRepository.findById(connectionId).orElse(null);
if (conn == null) {
throw new NotFoundException("Connection not found: " + connectionId);
}
if (!conn.getUserId().equals(userId)) {
throw new AccessDeniedException("Access denied to connection: " + connectionId);
}
// ── decrypt credentials ───────────────────────────────────────────────
String password = connectionService.getDecryptedPassword(conn);
String privateKey = connectionService.getDecryptedPrivateKey(conn);
String passphrase = connectionService.getDecryptedPassphrase(conn);
// ── create tunnel ─────────────────────────────────────────────────────
try {
TunnelEntry entry = portForwardRegistry.create(
userId, conn, password, privateKey, passphrase,
localPort, remoteHost, remotePort);
return ResponseEntity.ok(toDto(entry));
} catch (IllegalArgumentException e) {
Map<String, String> err = new HashMap<>();
err.put("error", e.getMessage());
return ResponseEntity.badRequest().body(err);
} catch (Exception e) {
Map<String, String> err = new HashMap<>();
err.put("error", "Failed to create port-forward: " + e.getMessage());
return ResponseEntity.internalServerError().body(err);
}
}
/** DELETE /api/port-forwards/{id} — stop and remove a tunnel. */
@DeleteMapping("/{id}")
public ResponseEntity<Map<String, String>> stop(@PathVariable String id, Authentication auth) {
Long userId = getCurrentUserId(auth);
try {
portForwardRegistry.stop(id, userId);
Map<String, String> ok = new HashMap<>();
ok.put("message", "Tunnel stopped");
return ResponseEntity.ok(ok);
} catch (IllegalArgumentException e) {
Map<String, String> err = new HashMap<>();
err.put("error", e.getMessage());
return ResponseEntity.badRequest().body(err);
}
}
}
@@ -9,6 +9,8 @@ import com.sshmanager.service.QuickConnectionRegistry;
import com.sshmanager.service.QuickCredentialRegistry;
import com.sshmanager.service.SshService;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
@@ -34,7 +36,11 @@ public class TerminalWebSocketHandler extends TextWebSocketHandler {
private final QuickCredentialRegistry quickCredentials;
private final ExecutorService executor;
@Value("${sshmanager.terminal.idle-timeout-minutes:30}")
private long idleTimeoutMinutes;
private final AtomicInteger sessionCount = new AtomicInteger(0);
private final Map<String, WebSocketSession> wsSessions = new ConcurrentHashMap<>();
private final Map<String, SshService.SshSession> sessions = new ConcurrentHashMap<>();
private final Map<String, Long> lastActivity = new ConcurrentHashMap<>();
@@ -100,9 +106,15 @@ public class TerminalWebSocketHandler extends TextWebSocketHandler {
try {
SshService.SshSession sshSession = sshService.createShellSession(conn, password, privateKey, passphrase);
sessions.put(webSocketSession.getId(), sshSession);
wsSessions.put(webSocketSession.getId(), webSocketSession);
lastActivity.put(webSocketSession.getId(), System.currentTimeMillis());
sessionCount.incrementAndGet();
// Refresh the quick-connection TTL so an active terminal is not evicted mid-session
if (quickConnections.get(connectionId) != null) {
quickConnections.touch(connectionId);
}
executor.submit(() -> {
try {
InputStream in = sshSession.getOutputStream();
@@ -132,6 +144,12 @@ public class TerminalWebSocketHandler extends TextWebSocketHandler {
if (sshSession != null && sshSession.isConnected()) {
lastActivity.put(webSocketSession.getId(), System.currentTimeMillis());
// Touch the quick connection registry to keep metadata alive
Long connectionId = (Long) webSocketSession.getAttributes().get("connectionId");
if (connectionId != null) {
quickConnections.touch(connectionId);
}
String payload = message.getPayload();
TerminalControlMessage.parse(payload).ifPresent(ctrl -> {
if ("resize".equals(ctrl.getType()) && ctrl.getCols() != null && ctrl.getRows() != null) {
@@ -151,6 +169,7 @@ public class TerminalWebSocketHandler extends TextWebSocketHandler {
public void afterConnectionClosed(@NonNull WebSocketSession webSocketSession, @NonNull CloseStatus status) throws Exception {
SshService.SshSession sshSession = sessions.remove(webSocketSession.getId());
lastActivity.remove(webSocketSession.getId());
wsSessions.remove(webSocketSession.getId());
if (sshSession != null) {
sshSession.disconnect();
sessionCount.decrementAndGet();
@@ -165,4 +184,23 @@ public class TerminalWebSocketHandler extends TextWebSocketHandler {
}
}
}
@Scheduled(fixedDelay = 60000)
public void cleanupIdleSessions() {
long now = System.currentTimeMillis();
long maxIdleMillis = idleTimeoutMinutes * 60_000L;
lastActivity.forEach((sessionId, lastTime) -> {
if (now - lastTime > maxIdleMillis) {
WebSocketSession ws = wsSessions.get(sessionId);
if (ws != null && ws.isOpen()) {
try {
ws.sendMessage(new TextMessage("\r\n[Session idle timeout closed]\r\n"));
ws.close(CloseStatus.GOING_AWAY);
} catch (IOException e) {
// ignore
}
}
}
});
}
}
@@ -4,18 +4,43 @@ import com.sshmanager.dto.ConnectionStatusCheckRequest;
import com.sshmanager.dto.ConnectionStatusItemDto;
import com.sshmanager.dto.ConnectionStatusResponseDto;
import com.sshmanager.entity.Connection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.annotation.PreDestroy;
@Service
public class ConnectionStatusService {
private final ConnectionService connectionService;
private static final Logger log = LoggerFactory.getLogger(ConnectionStatusService.class);
private static final int MAX_CONCURRENCY = 20;
private static final int PROBE_TIMEOUT_SECONDS = 5;
public ConnectionStatusService(ConnectionService connectionService) {
private final ConnectionService connectionService;
private final TcpProbe tcpProbe;
private final ExecutorService executor;
public ConnectionStatusService(ConnectionService connectionService, TcpProbe tcpProbe) {
this.connectionService = connectionService;
this.tcpProbe = tcpProbe;
this.executor = Executors.newFixedThreadPool(MAX_CONCURRENCY, r -> {
Thread t = new Thread(r, "status-probe-" + r.hashCode());
t.setDaemon(true);
return t;
});
}
@PreDestroy
public void shutdown() {
executor.shutdownNow();
}
public ConnectionStatusResponseDto checkStatuses(Long userId, ConnectionStatusCheckRequest request) {
@@ -23,39 +48,28 @@ public class ConnectionStatusService {
throw new IllegalArgumentException("At least one connection is required");
}
List<ConnectionStatusItemDto> results = new ArrayList<ConnectionStatusItemDto>();
// Execute all probes in parallel with a bounded thread pool
List<CompletableFuture<ConnectionStatusItemDto>> futures = new ArrayList<>();
for (Long connectionId : request.getConnectionIds()) {
futures.add(CompletableFuture.supplyAsync(() -> probeConnection(connectionId, userId), executor));
}
List<ConnectionStatusItemDto> results = new ArrayList<>();
int onlineCount = 0;
int offlineCount = 0;
for (Long connectionId : request.getConnectionIds()) {
Connection connection = connectionService.getConnectionForSsh(connectionId, userId);
long startedAt = System.currentTimeMillis();
for (CompletableFuture<ConnectionStatusItemDto> future : futures) {
try {
connectionService.testConnection(
connection,
connectionService.getDecryptedPassword(connection),
connectionService.getDecryptedPrivateKey(connection),
connectionService.getDecryptedPassphrase(connection)
);
long durationMs = System.currentTimeMillis() - startedAt;
results.add(new ConnectionStatusItemDto(
connection.getId(),
connection.getName(),
"online",
"SSH connection available",
durationMs
));
onlineCount += 1;
} catch (Exception error) {
long durationMs = System.currentTimeMillis() - startedAt;
results.add(new ConnectionStatusItemDto(
connection.getId(),
connection.getName(),
"offline",
error.getMessage(),
durationMs
));
offlineCount += 1;
ConnectionStatusItemDto result = future.get(PROBE_TIMEOUT_SECONDS * 2, TimeUnit.SECONDS);
results.add(result);
if ("online".equals(result.getStatus())) {
onlineCount++;
} else {
offlineCount++;
}
} catch (Exception e) {
results.add(new ConnectionStatusItemDto(0L, null, "offline", "Probe timeout", 0));
offlineCount++;
}
}
@@ -66,4 +80,30 @@ public class ConnectionStatusService {
response.setResults(results);
return response;
}
private ConnectionStatusItemDto probeConnection(Long connectionId, Long userId) {
long startedAt = System.currentTimeMillis();
try {
Connection connection = connectionService.getConnectionForSsh(connectionId, userId);
// Quick TCP probe — no SSH handshake, just checks port liveness
tcpProbe.checkReachable(connection, PROBE_TIMEOUT_SECONDS);
long durationMs = System.currentTimeMillis() - startedAt;
return new ConnectionStatusItemDto(
connection.getId(), connection.getName(),
"online", "TCP port reachable", durationMs
);
} catch (Exception error) {
long durationMs = System.currentTimeMillis() - startedAt;
Long id = connectionId;
String name = null;
try {
Connection conn = connectionService.getConnectionForSsh(connectionId, userId);
id = conn.getId();
name = conn.getName();
} catch (Exception ignored) {}
return new ConnectionStatusItemDto(
id, name, "offline", error.getMessage(), durationMs
);
}
}
}
@@ -0,0 +1,176 @@
package com.sshmanager.service;
import com.jcraft.jsch.Session;
import com.sshmanager.entity.Connection;
import com.sshmanager.util.JschUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
/**
* In-memory registry for active SSH port-forwarding tunnels.
*
* <p>All tunnels are ephemeral — they exist only for the lifetime of the server
* process. Restarting the server clears all tunnels automatically.
*
* <p>One JSch {@link Session} is kept per tunnel. When a tunnel is stopped the
* session is disconnected and the local port is freed.
*/
@Component
public class PortForwardRegistry {
private static final Logger log = LoggerFactory.getLogger(PortForwardRegistry.class);
// ── public data model ────────────────────────────────────────────────────
public enum TunnelStatus { RUNNING, STOPPED, ERROR }
public static class TunnelEntry {
private final String id;
private final Long userId;
private final Long connectionId;
private final String connectionName;
private final int localPort;
private final String remoteHost;
private final int remotePort;
private final Instant createdAt;
private volatile TunnelStatus status;
// not exposed in DTO
final Session jschSession;
public TunnelEntry(String id, Long userId, Long connectionId, String connectionName,
int localPort, String remoteHost, int remotePort,
Session jschSession) {
this.id = id;
this.userId = userId;
this.connectionId = connectionId;
this.connectionName = connectionName;
this.localPort = localPort;
this.remoteHost = remoteHost;
this.remotePort = remotePort;
this.jschSession = jschSession;
this.createdAt = Instant.now();
this.status = TunnelStatus.RUNNING;
}
public String getId() { return id; }
public Long getUserId() { return userId; }
public Long getConnectionId() { return connectionId; }
public String getConnectionName(){ return connectionName; }
public int getLocalPort() { return localPort; }
public String getRemoteHost() { return remoteHost; }
public int getRemotePort() { return remotePort; }
public Instant getCreatedAt() { return createdAt; }
public TunnelStatus getStatus() { return status; }
void setStatus(TunnelStatus s) { this.status = s; }
}
// ── state ────────────────────────────────────────────────────────────────
private final Map<String, TunnelEntry> tunnels = new ConcurrentHashMap<>();
private final SshSessionFactory sshSessionFactory;
public PortForwardRegistry(SshSessionFactory sshSessionFactory) {
this.sshSessionFactory = sshSessionFactory;
}
// ── public API ───────────────────────────────────────────────────────────
/**
* Create and register a new port-forwarding tunnel.
*
* @param userId owner's user ID
* @param conn the SSH connection entity to tunnel through
* @param password decrypted password (null for key auth)
* @param privateKey decrypted private key PEM (null for password auth)
* @param passphrase key passphrase (null if none)
* @param localPort local TCP port to bind (165535)
* @param remoteHost the host reachable from the SSH server to forward to
* @param remotePort the port on remoteHost to forward to (165535)
* @return the registered {@link TunnelEntry}
* @throws IllegalArgumentException if port numbers are out of range or remoteHost is blank
* @throws Exception if the SSH session or port-forward setup fails
*/
public TunnelEntry create(Long userId, Connection conn,
String password, String privateKey, String passphrase,
int localPort, String remoteHost, int remotePort) throws Exception {
validatePorts(localPort, remotePort);
if (remoteHost == null || remoteHost.trim().isEmpty()) {
throw new IllegalArgumentException("remoteHost must not be blank");
}
Session session = sshSessionFactory.createSession(conn, password, privateKey, passphrase);
try {
session.setPortForwardingL(localPort, remoteHost, remotePort);
} catch (Exception e) {
session.disconnect();
throw e;
}
String id = UUID.randomUUID().toString().replace("-", "");
TunnelEntry entry = new TunnelEntry(id, userId, conn.getId(), conn.getName(),
localPort, remoteHost, remotePort, session);
tunnels.put(id, entry);
log.info("Port-forward started: id={} user={} {}:{}:{} via connection={}",
id, userId, localPort, remoteHost, remotePort, conn.getId());
return entry;
}
/**
* Stop and remove a tunnel.
*
* @param id tunnel ID
* @param userId caller's user ID — must match the tunnel owner
* @throws IllegalArgumentException if the tunnel is not found or not owned by this user
*/
public void stop(String id, Long userId) {
TunnelEntry entry = tunnels.get(id);
if (entry == null) {
throw new IllegalArgumentException("Port-forward tunnel not found: " + id);
}
if (!entry.getUserId().equals(userId)) {
throw new IllegalArgumentException("Access denied to tunnel: " + id);
}
tunnels.remove(id);
try {
entry.jschSession.delPortForwardingL(entry.getLocalPort());
} catch (Exception ignored) {
// best-effort cancel
}
if (entry.jschSession.isConnected()) {
entry.jschSession.disconnect();
}
entry.setStatus(TunnelStatus.STOPPED);
log.info("Port-forward stopped: id={} user={} localPort={}", id, userId, entry.getLocalPort());
}
/** List all active tunnels belonging to a user. */
public List<TunnelEntry> listByUser(Long userId) {
List<TunnelEntry> result = new ArrayList<>();
for (TunnelEntry e : tunnels.values()) {
if (e.getUserId().equals(userId)) {
result.add(e);
}
}
result.sort((a, b) -> a.getCreatedAt().compareTo(b.getCreatedAt()));
return result;
}
// ── helpers ──────────────────────────────────────────────────────────────
private static void validatePorts(int localPort, int remotePort) {
if (localPort < 1 || localPort > 65535) {
throw new IllegalArgumentException("localPort must be in range 1-65535, got: " + localPort);
}
if (remotePort < 1 || remotePort > 65535) {
throw new IllegalArgumentException("remotePort must be in range 1-65535, got: " + remotePort);
}
}
}
@@ -1,6 +1,10 @@
package com.sshmanager.service;
import com.sshmanager.entity.Connection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.time.Instant;
@@ -10,15 +14,30 @@ import java.util.concurrent.atomic.AtomicLong;
/**
* In-memory registry for quick-connect (ephemeral) SSH connections.
* Entries are NOT persisted to the database and are cleaned up
* when the WebSocket session closes.
*
* <p>Entries are NOT persisted to the database. They are cleaned up either:
* <ul>
* <li>immediately when the WebSocket session closes ({@link #remove}), or</li>
* <li>by the background TTL sweep ({@link #evictExpired}) after
* {@code sshmanager.quick-connection.ttl-minutes} minutes of inactivity.</li>
* </ul>
*
* <p>Call {@link #touch} after a successful WebSocket handshake to reset the
* last-active timestamp and prevent premature eviction of long-lived sessions.
*/
@Component
public class QuickConnectionRegistry {
private static final Logger log = LoggerFactory.getLogger(QuickConnectionRegistry.class);
@Value("${sshmanager.quick-connection.ttl-minutes:30}")
private long ttlMinutes;
private final AtomicLong idGen = new AtomicLong(10_000_000);
private final Map<Long, Entry> entries = new ConcurrentHashMap<>();
// ── public API ───────────────────────────────────────────────────────────
public Connection create(String host, String username, int port, Long userId) {
long id = idGen.incrementAndGet();
Connection conn = new Connection();
@@ -36,11 +55,25 @@ public class QuickConnectionRegistry {
return conn;
}
/** Retrieve a quick connection; returns null if not found or already expired. */
public Connection get(Long id) {
Entry entry = entries.get(id);
return entry != null ? entry.connection : null;
}
/**
* Refresh the last-active timestamp for a quick connection.
* Call this after a WebSocket session is successfully established so that
* the entry is not evicted while the terminal is still in use.
*/
public void touch(Long id) {
Entry entry = entries.get(id);
if (entry != null) {
entry.lastAccessAt = System.currentTimeMillis();
}
}
/** Immediately remove a quick connection (called on WebSocket close). */
public void remove(Long id) {
entries.remove(id);
}
@@ -49,9 +82,41 @@ public class QuickConnectionRegistry {
return entries.size();
}
// ── background cleanup ───────────────────────────────────────────────────
/**
* Scheduled eviction of stale quick connections.
*
* <p>The interval is controlled by
* {@code sshmanager.quick-connection.cleanup-interval-ms} (default 60 s).
* An entry is evicted when its last-active time exceeds
* {@code sshmanager.quick-connection.ttl-minutes}.
*/
@Scheduled(fixedDelayString = "${sshmanager.quick-connection.cleanup-interval-ms:60000}")
public void evictExpired() {
long ttlMillis = ttlMinutes * 60_000L;
long now = System.currentTimeMillis();
int[] removed = {0};
entries.entrySet().removeIf(e -> {
boolean expired = (now - e.getValue().lastAccessAt) > ttlMillis;
if (expired) {
removed[0]++;
log.info("Evicting stale quick connection id={} (inactive > {} min)",
e.getKey(), ttlMinutes);
}
return expired;
});
if (removed[0] > 0) {
log.info("Quick-connection TTL sweep removed {} entr{}, {} remaining",
removed[0], removed[0] == 1 ? "y" : "ies", entries.size());
}
}
// ── internal entry ───────────────────────────────────────────────────────
private static class Entry {
final Connection connection;
final long createdAt = System.currentTimeMillis();
volatile long lastAccessAt = System.currentTimeMillis();
Entry(Connection connection) {
this.connection = connection;
@@ -0,0 +1,27 @@
package com.sshmanager.service;
import com.sshmanager.entity.Connection;
import org.springframework.stereotype.Component;
import java.net.InetSocketAddress;
import java.net.Socket;
/**
* Production TCP probe implementation.
* Opens a raw TCP socket to the target host:port — no SSH handshake.
*/
@Component
public class RealTcpProbe implements TcpProbe {
@Override
public void checkReachable(Connection conn, int timeoutSeconds) throws Exception {
Socket socket = new Socket();
try {
socket.connect(new InetSocketAddress(conn.getHost(), conn.getPort()), timeoutSeconds * 1000);
} catch (Exception e) {
try { socket.close(); } catch (Exception ignored) {}
throw e;
}
socket.close();
}
}
@@ -0,0 +1,24 @@
package com.sshmanager.service;
import com.jcraft.jsch.Session;
import com.sshmanager.entity.Connection;
import com.sshmanager.util.JschUtil;
import org.springframework.stereotype.Component;
/**
* A Spring-managed factory for JSch sessions.
*
* <p>Wraps the static utility calls to {@link JschUtil#createSession} so that
* dependent components can be easily tested via dependency injection and standard mocking.
*/
@Component
public class SshSessionFactory {
/**
* Create and connect an SSH session.
*/
public Session createSession(Connection conn, String password, String privateKey, String passphrase)
throws Exception {
return JschUtil.createSession(conn, password, privateKey, passphrase);
}
}
@@ -0,0 +1,12 @@
package com.sshmanager.service;
import com.sshmanager.entity.Connection;
/**
* Abstraction for TCP connectivity probing.
* Production: {@link RealTcpProbe} uses raw Socket.connect().
* Tests: mock to return success/failure without real network.
*/
public interface TcpProbe {
void checkReachable(Connection conn, int timeoutSeconds) throws Exception;
}
@@ -39,8 +39,17 @@ sshmanager:
jwt-expiration-ms: 86400000
password-expiration-days: ${SSHMANAGER_PASSWORD_EXPIRATION_DAYS:90}
terminal:
# Idle timeout threshold for active terminal websocket sessions, in minutes.
# Disconnects the underlying SSH session and closes the terminal on timeout. Default: 30.
idle-timeout-minutes: 30
websocket:
thread-pool:
core-size: 10
max-size: 50
keep-alive-seconds: 60
quick-connection:
# Time-to-live for idle quick (ephemeral) SSH connections, in minutes.
# An active WebSocket terminal resets the timer. Default: 30 minutes.
ttl-minutes: 30
# How often to run the TTL cleanup sweep, in milliseconds. Default: 60 s.
cleanup-interval-ms: 60000
@@ -1 +1 @@
import{c as r,j as e}from"./index-Z2D8CQl5.js";const l=[["circle",{cx:"12",cy:"12",r:"10",key:"1mglay"}],["line",{x1:"12",x2:"12",y1:"8",y2:"12",key:"1pkeuh"}],["line",{x1:"12",x2:"12.01",y1:"16",y2:"16",key:"4dfq90"}]],m=r("circle-alert",l);const o=[["path",{d:"M10 11v6",key:"nco0om"}],["path",{d:"M14 11v6",key:"outv1u"}],["path",{d:"M19 6v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6",key:"miytrc"}],["path",{d:"M3 6h18",key:"d0wm0j"}],["path",{d:"M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2",key:"e791ji"}]],b=r("trash-2",o);const i=[["path",{d:"M18 6 6 18",key:"1bl5f8"}],["path",{d:"m6 6 12 12",key:"d8bk6v"}]],x=r("x",i);function h({title:c,onClose:t,children:d,footer:a,maxWidth:s="max-w-3xl",open:n=!0}){return n?e.jsx("div",{className:"fixed inset-0 z-50 flex items-center justify-center bg-black/70 p-4 backdrop-blur-sm",children:e.jsxs("div",{className:`flex max-h-[92vh] w-full flex-col overflow-hidden rounded-3xl border border-border-main bg-surface-card ${s}`,children:[e.jsxs("div",{className:"flex items-center justify-between border-b border-border-subtle bg-surface-card/90 px-5 py-4",children:[e.jsx("h3",{className:"text-lg font-medium text-content-main",children:c}),t?e.jsx("button",{onClick:t,className:"rounded-xl border border-border-main bg-surface-muted p-2 text-content-muted transition hover:text-content-main",children:e.jsx(x,{size:18})}):null]}),e.jsx("div",{className:"flex-1 overflow-y-auto p-6",children:d}),a?e.jsx("div",{className:"flex justify-end gap-3 border-t border-border-subtle bg-surface-card/90 px-5 py-4",children:a}):null]})}):null}export{m as C,h as M,b as T,x as X};
import{c as r,j as e}from"./index-BQbRYAGj.js";const l=[["circle",{cx:"12",cy:"12",r:"10",key:"1mglay"}],["line",{x1:"12",x2:"12",y1:"8",y2:"12",key:"1pkeuh"}],["line",{x1:"12",x2:"12.01",y1:"16",y2:"16",key:"4dfq90"}]],m=r("circle-alert",l);const o=[["path",{d:"M10 11v6",key:"nco0om"}],["path",{d:"M14 11v6",key:"outv1u"}],["path",{d:"M19 6v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6",key:"miytrc"}],["path",{d:"M3 6h18",key:"d0wm0j"}],["path",{d:"M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2",key:"e791ji"}]],b=r("trash-2",o);const i=[["path",{d:"M18 6 6 18",key:"1bl5f8"}],["path",{d:"m6 6 12 12",key:"d8bk6v"}]],x=r("x",i);function h({title:c,onClose:t,children:d,footer:a,maxWidth:s="max-w-3xl",open:n=!0}){return n?e.jsx("div",{className:"fixed inset-0 z-50 flex items-center justify-center bg-black/70 p-4 backdrop-blur-sm",children:e.jsxs("div",{className:`flex max-h-[92vh] w-full flex-col overflow-hidden rounded-3xl border border-border-main bg-surface-card ${s}`,children:[e.jsxs("div",{className:"flex items-center justify-between border-b border-border-subtle bg-surface-card/90 px-5 py-4",children:[e.jsx("h3",{className:"text-lg font-medium text-content-main",children:c}),t?e.jsx("button",{onClick:t,className:"rounded-xl border border-border-main bg-surface-muted p-2 text-content-muted transition hover:text-content-main",children:e.jsx(x,{size:18})}):null]}),e.jsx("div",{className:"flex-1 overflow-y-auto p-6",children:d}),a?e.jsx("div",{className:"flex justify-end gap-3 border-t border-border-subtle bg-surface-card/90 px-5 py-4",children:a}):null]})}):null}export{m as C,h as M,b as T,x as X};
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
import{c as o,h as t}from"./index-Z2D8CQl5.js";const s=[["path",{d:"M22 12h-2.48a2 2 0 0 0-1.93 1.46l-2.35 8.36a.25.25 0 0 1-.48 0L9.24 2.18a.25.25 0 0 0-.48 0l-2.35 8.36A2 2 0 0 1 4.49 12H2",key:"169zse"}]],y=o("activity",s);const r=[["rect",{width:"20",height:"14",x:"2",y:"3",rx:"2",key:"48i651"}],["line",{x1:"8",x2:"16",y1:"21",y2:"21",key:"1svkeh"}],["line",{x1:"12",x2:"12",y1:"17",y2:"21",key:"vw1qmm"}]],h=o("monitor",r);const a=[["rect",{width:"20",height:"8",x:"2",y:"2",rx:"2",ry:"2",key:"ngkwjq"}],["rect",{width:"20",height:"8",x:"2",y:"14",rx:"2",ry:"2",key:"iecqi9"}],["line",{x1:"6",x2:"6.01",y1:"6",y2:"6",key:"16zg32"}],["line",{x1:"6",x2:"6.01",y1:"18",y2:"18",key:"nzw8ys"}]],x=o("server",a);function d(){return t.get("/connections")}function k(n){return t.post("/connections",n)}function p(n,e){return t.put(`/connections/${n}`,e)}function g(n){return t.delete(`/connections/${n}`)}function l(n,e){return t.post("/connections/batch-command",{connectionIds:n,command:e})}function f(n){return t.post("/connections/status",{connectionIds:n})}function m(n){return t.put(`/connections/${n}/pin`)}function C(n,e,c,i){return t.post("/connections/quick-connect",{host:n,username:e,port:c,password:i})}function q(n){return t.get(`/monitor/${n}`)}export{y as A,h as M,x as S,k as a,f as c,g as d,l as e,q as g,d as l,C as q,m as t,p as u};
import{c as o,h as t}from"./index-BQbRYAGj.js";const s=[["path",{d:"M22 12h-2.48a2 2 0 0 0-1.93 1.46l-2.35 8.36a.25.25 0 0 1-.48 0L9.24 2.18a.25.25 0 0 0-.48 0l-2.35 8.36A2 2 0 0 1 4.49 12H2",key:"169zse"}]],y=o("activity",s);const r=[["rect",{width:"20",height:"14",x:"2",y:"3",rx:"2",key:"48i651"}],["line",{x1:"8",x2:"16",y1:"21",y2:"21",key:"1svkeh"}],["line",{x1:"12",x2:"12",y1:"17",y2:"21",key:"vw1qmm"}]],h=o("monitor",r);const a=[["rect",{width:"20",height:"8",x:"2",y:"2",rx:"2",ry:"2",key:"ngkwjq"}],["rect",{width:"20",height:"8",x:"2",y:"14",rx:"2",ry:"2",key:"iecqi9"}],["line",{x1:"6",x2:"6.01",y1:"6",y2:"6",key:"16zg32"}],["line",{x1:"6",x2:"6.01",y1:"18",y2:"18",key:"nzw8ys"}]],x=o("server",a);function d(){return t.get("/connections")}function k(n){return t.post("/connections",n)}function p(n,e){return t.put(`/connections/${n}`,e)}function g(n){return t.delete(`/connections/${n}`)}function l(n,e){return t.post("/connections/batch-command",{connectionIds:n,command:e})}function f(n){return t.post("/connections/status",{connectionIds:n})}function m(n){return t.put(`/connections/${n}/pin`)}function C(n,e,c,i){return t.post("/connections/quick-connect",{host:n,username:e,port:c,password:i})}function q(n){return t.get(`/monitor/${n}`)}export{y as A,h as M,x as S,k as a,f as c,g as d,l as e,q as g,d as l,C as q,m as t,p as u};
+2 -2
View File
@@ -11,8 +11,8 @@
rel="stylesheet"
/>
<title>SSH Manager</title>
<script type="module" crossorigin src="/assets/index-Z2D8CQl5.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-B4Duc4SL.css">
<script type="module" crossorigin src="/assets/index-BQbRYAGj.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-CPovcnGC.css">
</head>
<body>
<div id="app"></div>
@@ -0,0 +1,204 @@
package com.sshmanager.controller;
import com.sshmanager.entity.Connection;
import com.sshmanager.entity.User;
import com.sshmanager.exception.AccessDeniedException;
import com.sshmanager.exception.NotFoundException;
import com.sshmanager.repository.ConnectionRepository;
import com.sshmanager.repository.UserRepository;
import com.sshmanager.service.ConnectionService;
import com.sshmanager.service.PortForwardRegistry;
import com.sshmanager.service.PortForwardRegistry.TunnelEntry;
import com.sshmanager.service.PortForwardRegistry.TunnelStatus;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.http.ResponseEntity;
import org.springframework.security.core.Authentication;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
/**
* Unit tests for {@link PortForwardController}.
*/
@ExtendWith(MockitoExtension.class)
@SuppressWarnings("unchecked")
class PortForwardControllerTest {
@Mock
private PortForwardRegistry portForwardRegistry;
@Mock
private ConnectionRepository connectionRepository;
@Mock
private ConnectionService connectionService;
@Mock
private UserRepository userRepository;
@InjectMocks
private PortForwardController portForwardController;
private Authentication authentication;
private User testUser;
private Connection connection;
@BeforeEach
void setUp() {
authentication = mock(Authentication.class);
when(authentication.getName()).thenReturn("testuser");
testUser = new User();
testUser.setId(100L);
testUser.setUsername("testuser");
when(userRepository.findByUsername("testuser")).thenReturn(Optional.of(testUser));
connection = new Connection();
connection.setId(1L);
connection.setUserId(100L);
connection.setName("test-conn");
}
@Test
void list_returnsActiveTunnels() {
TunnelEntry mockEntry = mock(TunnelEntry.class);
when(mockEntry.getId()).thenReturn("tunnel-id");
when(mockEntry.getConnectionId()).thenReturn(1L);
when(mockEntry.getConnectionName()).thenReturn("test-conn");
when(mockEntry.getLocalPort()).thenReturn(8080);
when(mockEntry.getRemoteHost()).thenReturn("127.0.0.1");
when(mockEntry.getRemotePort()).thenReturn(3306);
when(mockEntry.getStatus()).thenReturn(TunnelStatus.RUNNING);
when(mockEntry.getCreatedAt()).thenReturn(java.time.Instant.now());
when(portForwardRegistry.listByUser(100L)).thenReturn(Collections.singletonList(mockEntry));
ResponseEntity<List<Map<String, Object>>> response = (ResponseEntity<List<Map<String, Object>>>) (Object) portForwardController.list(authentication);
assertEquals(200, response.getStatusCode().value());
List<Map<String, Object>> body = response.getBody();
assertNotNull(body);
assertEquals(1, body.size());
assertEquals("tunnel-id", body.get(0).get("id"));
}
@Test
void create_success() throws Exception {
Map<String, Object> body = new HashMap<>();
body.put("connectionId", 1L);
body.put("localPort", 8080);
body.put("remoteHost", "127.0.0.1");
body.put("remotePort", 3306);
when(connectionRepository.findById(1L)).thenReturn(Optional.of(connection));
when(connectionService.getDecryptedPassword(connection)).thenReturn("password");
TunnelEntry mockEntry = mock(TunnelEntry.class);
when(mockEntry.getId()).thenReturn("tunnel-id");
when(mockEntry.getConnectionId()).thenReturn(1L);
when(mockEntry.getConnectionName()).thenReturn("test-conn");
when(mockEntry.getLocalPort()).thenReturn(8080);
when(mockEntry.getRemoteHost()).thenReturn("127.0.0.1");
when(mockEntry.getRemotePort()).thenReturn(3306);
when(mockEntry.getStatus()).thenReturn(TunnelStatus.RUNNING);
when(mockEntry.getCreatedAt()).thenReturn(java.time.Instant.now());
when(portForwardRegistry.create(eq(100L), eq(connection), eq("password"), any(), any(), eq(8080), eq("127.0.0.1"), eq(3306)))
.thenReturn(mockEntry);
ResponseEntity<?> response = portForwardController.create(body, authentication);
assertEquals(200, response.getStatusCode().value());
Map<String, Object> respBody = (Map<String, Object>) response.getBody();
assertNotNull(respBody);
assertEquals("tunnel-id", respBody.get("id"));
}
@Test
void create_invalidRequestBody() {
Map<String, Object> body = new HashMap<>();
// missing connectionId and ports
ResponseEntity<?> response = portForwardController.create(body, authentication);
assertEquals(400, response.getStatusCode().value());
Map<String, String> respBody = (Map<String, String>) response.getBody();
assertNotNull(respBody);
assertTrue(respBody.get("error").contains("Invalid request body"));
}
@Test
void create_connectionNotFound() {
Map<String, Object> body = new HashMap<>();
body.put("connectionId", 999L);
body.put("localPort", 8080);
body.put("remoteHost", "127.0.0.1");
body.put("remotePort", 3306);
when(connectionRepository.findById(999L)).thenReturn(Optional.empty());
assertThrows(NotFoundException.class, () ->
portForwardController.create(body, authentication));
}
@Test
void create_accessDenied() {
Map<String, Object> body = new HashMap<>();
body.put("connectionId", 1L);
body.put("localPort", 8080);
body.put("remoteHost", "127.0.0.1");
body.put("remotePort", 3306);
Connection otherUserConnection = new Connection();
otherUserConnection.setId(1L);
otherUserConnection.setUserId(200L); // different owner
when(connectionRepository.findById(1L)).thenReturn(Optional.of(otherUserConnection));
assertThrows(AccessDeniedException.class, () ->
portForwardController.create(body, authentication));
}
@Test
void create_internalServerError() throws Exception {
Map<String, Object> body = new HashMap<>();
body.put("connectionId", 1L);
body.put("localPort", 8080);
body.put("remoteHost", "127.0.0.1");
body.put("remotePort", 3306);
when(connectionRepository.findById(1L)).thenReturn(Optional.of(connection));
when(portForwardRegistry.create(any(), any(), any(), any(), any(), anyInt(), anyString(), anyInt()))
.thenThrow(new RuntimeException("SSH failed"));
ResponseEntity<?> response = portForwardController.create(body, authentication);
assertEquals(500, response.getStatusCode().value());
Map<String, String> respBody = (Map<String, String>) response.getBody();
assertNotNull(respBody);
assertEquals("Failed to create port-forward: SSH failed", respBody.get("error"));
}
@Test
void stop_success() {
ResponseEntity<Map<String, String>> response = portForwardController.stop("tunnel-id", authentication);
assertEquals(200, response.getStatusCode().value());
verify(portForwardRegistry).stop("tunnel-id", 100L);
}
@Test
void stop_throwsIllegalArgumentException() {
doThrow(new IllegalArgumentException("Tunnel not found"))
.when(portForwardRegistry).stop("nonexistent-id", 100L);
ResponseEntity<Map<String, String>> response = portForwardController.stop("nonexistent-id", authentication);
assertEquals(400, response.getStatusCode().value());
assertEquals("Tunnel not found", response.getBody().get("error"));
}
}
@@ -60,6 +60,56 @@ class TerminalWebSocketHandlerTest {
verify(sshSession, never()).getInputStream();
}
@Test
@SuppressWarnings("unchecked")
void cleanupIdleSessionsClosesExpiredWebSockets() throws Exception {
ConnectionRepository connectionRepository = mock(ConnectionRepository.class);
UserRepository userRepository = mock(UserRepository.class);
ConnectionService connectionService = mock(ConnectionService.class);
SshService sshService = mock(SshService.class);
QuickConnectionRegistry quickConnections = mock(QuickConnectionRegistry.class);
QuickCredentialRegistry quickCredentials = mock(QuickCredentialRegistry.class);
ExecutorService executor = mock(ExecutorService.class);
TerminalWebSocketHandler handler = new TerminalWebSocketHandler(
connectionRepository,
userRepository,
connectionService,
sshService,
quickConnections,
quickCredentials,
executor
);
Field wsSessionsField = TerminalWebSocketHandler.class.getDeclaredField("wsSessions");
wsSessionsField.setAccessible(true);
Map<String, WebSocketSession> wsSessions = (Map<String, WebSocketSession>) wsSessionsField.get(handler);
Field lastActivityField = TerminalWebSocketHandler.class.getDeclaredField("lastActivity");
lastActivityField.setAccessible(true);
Map<String, Long> lastActivity = (Map<String, Long>) lastActivityField.get(handler);
Field idleTimeoutField = TerminalWebSocketHandler.class.getDeclaredField("idleTimeoutMinutes");
idleTimeoutField.setAccessible(true);
idleTimeoutField.set(handler, 30L);
WebSocketSession ws1 = mock(WebSocketSession.class);
when(ws1.isOpen()).thenReturn(true);
WebSocketSession ws2 = mock(WebSocketSession.class);
wsSessions.put("ws1", ws1);
lastActivity.put("ws1", System.currentTimeMillis() - (31 * 60 * 1000L));
wsSessions.put("ws2", ws2);
lastActivity.put("ws2", System.currentTimeMillis() - (10 * 60 * 1000L));
handler.cleanupIdleSessions();
verify(ws1).close(org.springframework.web.socket.CloseStatus.GOING_AWAY);
verify(ws2, never()).close(org.springframework.web.socket.CloseStatus.GOING_AWAY);
}
@SuppressWarnings("unchecked")
private Map<String, SshService.SshSession> sessionsMap(TerminalWebSocketHandler handler) throws Exception {
Field f = TerminalWebSocketHandler.class.getDeclaredField("sessions");
@@ -3,19 +3,20 @@ package com.sshmanager.service;
import com.sshmanager.dto.ConnectionStatusCheckRequest;
import com.sshmanager.dto.ConnectionStatusResponseDto;
import com.sshmanager.entity.Connection;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.Arrays;
import java.util.Collections;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.doReturn;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
@@ -24,49 +25,128 @@ class ConnectionStatusServiceTest {
@Mock
private ConnectionService connectionService;
@InjectMocks
@Mock
private TcpProbe tcpProbe;
private ConnectionStatusService connectionStatusService;
@Test
void checkStatusesAggregatesOnlineAndOfflineResults() {
Connection onlineConnection = new Connection();
onlineConnection.setId(1L);
onlineConnection.setUserId(99L);
onlineConnection.setName("prod");
@BeforeEach
void setUp() {
connectionStatusService = new ConnectionStatusService(connectionService, tcpProbe);
}
Connection offlineConnection = new Connection();
offlineConnection.setId(2L);
offlineConnection.setUserId(99L);
offlineConnection.setName("test");
@Test
void checkStatusesRejectsEmptyConnectionIds() {
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
IllegalArgumentException error = assertThrows(
IllegalArgumentException.class,
() -> connectionStatusService.checkStatuses(1L, request)
);
assertEquals("At least one connection is required", error.getMessage());
}
@Test
void checkStatusesMarksReachableAsOnline() throws Exception {
Connection conn = new Connection();
conn.setId(1L);
conn.setUserId(99L);
conn.setName("reachable-host");
conn.setHost("10.0.0.1");
conn.setPort(22);
when(connectionService.getConnectionForSsh(1L, 99L)).thenReturn(conn);
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
request.setConnectionIds(Arrays.asList(1L, 2L));
request.setConnectionIds(Arrays.asList(1L));
when(connectionService.getConnectionForSsh(1L, 99L)).thenReturn(onlineConnection);
when(connectionService.getConnectionForSsh(2L, 99L)).thenReturn(offlineConnection);
doReturn(onlineConnection).when(connectionService).testConnection(eq(onlineConnection), eq(null), eq(null), eq(null));
doThrow(new RuntimeException("Connection refused")).when(connectionService).testConnection(eq(offlineConnection), eq(null), eq(null), eq(null));
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
assertEquals(1, response.getTotal());
assertEquals(1, response.getOnlineCount());
assertEquals(0, response.getOfflineCount());
assertEquals("online", response.getResults().get(0).getStatus());
assertEquals("TCP port reachable", response.getResults().get(0).getMessage());
}
@Test
void checkStatusesMarksUnreachableAsOffline() throws Exception {
Connection conn = new Connection();
conn.setId(2L);
conn.setUserId(99L);
conn.setName("unreachable-host");
conn.setHost("10.0.0.2");
conn.setPort(22);
when(connectionService.getConnectionForSsh(2L, 99L)).thenReturn(conn);
doAnswer(invocation -> { throw new RuntimeException("Connection refused"); })
.when(tcpProbe).checkReachable(conn, 5);
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
request.setConnectionIds(Arrays.asList(2L));
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
assertEquals(1, response.getTotal());
assertEquals(0, response.getOnlineCount());
assertEquals(1, response.getOfflineCount());
assertEquals("offline", response.getResults().get(0).getStatus());
}
@Test
void checkStatusesAggregatesMixedResults() throws Exception {
Connection online = new Connection();
online.setId(10L);
online.setUserId(99L);
online.setName("server-a");
online.setHost("10.0.0.1");
online.setPort(22);
Connection offline = new Connection();
offline.setId(20L);
offline.setUserId(99L);
offline.setName("server-b");
offline.setHost("10.0.0.2");
offline.setPort(22);
when(connectionService.getConnectionForSsh(10L, 99L)).thenReturn(online);
when(connectionService.getConnectionForSsh(20L, 99L)).thenReturn(offline);
// Answer with argument matching based on connection ID
doAnswer(invocation -> {
Connection c = invocation.getArgument(0);
if (c.getId() == 20L) throw new RuntimeException("Connection refused");
return null;
}).when(tcpProbe).checkReachable(any(Connection.class), anyInt());
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
request.setConnectionIds(Arrays.asList(10L, 20L));
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
assertEquals(2, response.getTotal());
assertEquals(1, response.getOnlineCount());
assertEquals(1, response.getOfflineCount());
assertEquals("online", response.getResults().get(0).getStatus());
assertEquals("offline", response.getResults().get(1).getStatus());
assertEquals("prod", response.getResults().get(0).getConnectionName());
assertEquals("Connection refused", response.getResults().get(1).getMessage());
}
@Test
void checkStatusesRejectsEmptyConnectionIds() {
void checkStatusesHandlesAllOfflineGracefully() throws Exception {
Connection conn = new Connection();
conn.setId(99L);
conn.setUserId(99L);
conn.setName("fail-host");
conn.setHost("10.0.0.1");
conn.setPort(22);
when(connectionService.getConnectionForSsh(99L, 99L)).thenReturn(conn);
doAnswer(invocation -> { throw new RuntimeException("timeout"); })
.when(tcpProbe).checkReachable(conn, 5);
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
request.setConnectionIds(Collections.singletonList(99L));
IllegalArgumentException error = assertThrows(
IllegalArgumentException.class,
() -> connectionStatusService.checkStatuses(1L, request)
);
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
assertEquals("At least one connection is required", error.getMessage());
assertEquals(1, response.getTotal());
assertEquals(0, response.getOnlineCount());
assertEquals(1, response.getOfflineCount());
}
}
@@ -0,0 +1,191 @@
package com.sshmanager.service;
import com.jcraft.jsch.Session;
import com.sshmanager.entity.Connection;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.test.util.ReflectionTestUtils;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
/**
* Unit tests for {@link PortForwardRegistry}.
*
* <p>The JSch {@link Session} is mocked so no real SSH connections are made.
* Reflection is used to inject mock tunnels into the registry.
*/
@ExtendWith(MockitoExtension.class)
class PortForwardRegistryTest {
@Mock
private Session mockSession;
@Mock
private SshSessionFactory sshSessionFactory;
private PortForwardRegistry registry;
private Connection connection;
@BeforeEach
void setUp() {
registry = new PortForwardRegistry(sshSessionFactory);
connection = new Connection();
connection.setId(1L);
connection.setUserId(100L);
connection.setName("test-conn");
connection.setHost("example.com");
connection.setPort(22);
connection.setUsername("user");
connection.setAuthType(Connection.AuthType.PASSWORD);
}
@SuppressWarnings("unchecked")
private PortForwardRegistry.TunnelEntry addMockTunnel(Long userId, int localPort,
String remoteHost, int remotePort)
throws Exception {
String id = "test-id-" + localPort;
PortForwardRegistry.TunnelEntry entry = new PortForwardRegistry.TunnelEntry(
id, userId, connection.getId(), connection.getName(),
localPort, remoteHost, remotePort, mockSession);
Map<String, PortForwardRegistry.TunnelEntry> tunnels =
(Map<String, PortForwardRegistry.TunnelEntry>) ReflectionTestUtils.getField(registry, "tunnels");
if (tunnels != null) {
tunnels.put(id, entry);
}
return entry;
}
// ── validation tests ─────────────────────────────────────────────────────
@Test
void create_throwsOnInvalidLocalPort() {
assertThrows(IllegalArgumentException.class, () ->
registry.create(100L, connection, "pass", null, null, 0, "127.0.0.1", 3306));
}
@Test
void create_throwsOnPortAboveRange() {
assertThrows(IllegalArgumentException.class, () ->
registry.create(100L, connection, "pass", null, null, 70000, "127.0.0.1", 3306));
}
@Test
void create_throwsOnInvalidRemotePort() {
assertThrows(IllegalArgumentException.class, () ->
registry.create(100L, connection, "pass", null, null, 8080, "127.0.0.1", 0));
}
@Test
void create_throwsOnBlankRemoteHost() {
assertThrows(IllegalArgumentException.class, () ->
registry.create(100L, connection, "pass", null, null, 8080, " ", 3306));
}
@Test
void create_throwsOnNullRemoteHost() {
assertThrows(IllegalArgumentException.class, () ->
registry.create(100L, connection, "pass", null, null, 8080, null, 3306));
}
@Test
void create_success() throws Exception {
when(sshSessionFactory.createSession(any(), any(), any(), any()))
.thenReturn(mockSession);
when(mockSession.setPortForwardingL(anyInt(), anyString(), anyInt())).thenReturn(8080);
PortForwardRegistry.TunnelEntry entry = registry.create(100L, connection, "pass", null, null, 8080, "127.0.0.1", 3306);
assertNotNull(entry);
assertEquals(100L, entry.getUserId());
assertEquals(8080, entry.getLocalPort());
assertEquals("127.0.0.1", entry.getRemoteHost());
assertEquals(3306, entry.getRemotePort());
assertEquals(PortForwardRegistry.TunnelStatus.RUNNING, entry.getStatus());
verify(mockSession).setPortForwardingL(8080, "127.0.0.1", 3306);
}
// ── TunnelEntry model tests ───────────────────────────────────────────────
@Test
void tunnelEntry_initialStatusIsRunning() throws Exception {
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
assertEquals(PortForwardRegistry.TunnelStatus.RUNNING, entry.getStatus());
}
@Test
void tunnelEntry_fieldsAreCorrectlySet() throws Exception {
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "db.internal", 5432);
assertEquals(100L, entry.getUserId());
assertEquals(8080, entry.getLocalPort());
assertEquals("db.internal", entry.getRemoteHost());
assertEquals(5432, entry.getRemotePort());
assertNotNull(entry.getCreatedAt());
assertNotNull(entry.getId());
}
// ── stop tests ────────────────────────────────────────────────────────────
@Test
void stop_throwsOnUnknownId() {
assertThrows(IllegalArgumentException.class, () ->
registry.stop("nonexistent-id", 100L));
}
@Test
void stop_stopsActiveTunnel() throws Exception {
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
when(mockSession.isConnected()).thenReturn(true);
registry.stop(entry.getId(), 100L);
verify(mockSession).delPortForwardingL(8080);
verify(mockSession).disconnect();
assertEquals(PortForwardRegistry.TunnelStatus.STOPPED, entry.getStatus());
assertTrue(registry.listByUser(100L).isEmpty());
}
@Test
void stop_throwsOnAccessDenied() throws Exception {
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
assertThrows(IllegalArgumentException.class, () ->
registry.stop(entry.getId(), 200L)); // wrong user
assertEquals(PortForwardRegistry.TunnelStatus.RUNNING, entry.getStatus());
verify(mockSession, never()).disconnect();
}
// ── listByUser ────────────────────────────────────────────────────────────
@Test
void listByUser_returnsEmptyWhenNone() {
List<PortForwardRegistry.TunnelEntry> list = registry.listByUser(100L);
assertNotNull(list);
assertTrue(list.isEmpty());
}
@Test
void listByUser_returnsOnlyUserTunnels() throws Exception {
PortForwardRegistry.TunnelEntry e1 = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
PortForwardRegistry.TunnelEntry e2 = addMockTunnel(100L, 8081, "127.0.0.1", 5432);
PortForwardRegistry.TunnelEntry e3 = addMockTunnel(200L, 8082, "127.0.0.1", 6379);
List<PortForwardRegistry.TunnelEntry> user100Tunnels = registry.listByUser(100L);
assertEquals(2, user100Tunnels.size());
assertTrue(user100Tunnels.contains(e1));
assertTrue(user100Tunnels.contains(e2));
assertFalse(user100Tunnels.contains(e3));
List<PortForwardRegistry.TunnelEntry> user200Tunnels = registry.listByUser(200L);
assertEquals(1, user200Tunnels.size());
assertTrue(user200Tunnels.contains(e3));
}
}
@@ -0,0 +1,137 @@
package com.sshmanager.service;
import com.sshmanager.entity.Connection;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.test.util.ReflectionTestUtils;
import static org.junit.jupiter.api.Assertions.*;
/**
* Unit tests for {@link QuickConnectionRegistry}.
*
* <p>These tests exercise the core lifecycle (create / get / remove) and the TTL
* eviction logic without starting a Spring context.
*/
class QuickConnectionRegistryTest {
private QuickConnectionRegistry registry;
@BeforeEach
void setUp() {
registry = new QuickConnectionRegistry();
// Default TTL: 30 minutes (not relevant for most tests — overridden where needed)
ReflectionTestUtils.setField(registry, "ttlMinutes", 30L);
}
// ── basic lifecycle ───────────────────────────────────────────────────────
@Test
void create_returnsConnectionWithCorrectFields() {
Connection conn = registry.create("192.168.1.1", "admin", 22, 1L);
assertNotNull(conn.getId());
assertTrue(conn.getId() >= 10_000_001L, "ID should be in quick-connect namespace");
assertEquals("192.168.1.1", conn.getHost());
assertEquals("admin", conn.getUsername());
assertEquals(22, conn.getPort());
assertEquals(1L, conn.getUserId());
assertEquals(Connection.AuthType.PASSWORD, conn.getAuthType());
}
@Test
void get_returnsConnectionAfterCreate() {
Connection conn = registry.create("host", "user", 22, 1L);
Connection found = registry.get(conn.getId());
assertNotNull(found);
assertEquals(conn.getId(), found.getId());
}
@Test
void get_returnsNullForUnknownId() {
assertNull(registry.get(999_999L));
}
@Test
void remove_deletesEntry() {
Connection conn = registry.create("host", "user", 22, 1L);
registry.remove(conn.getId());
assertNull(registry.get(conn.getId()));
assertEquals(0, registry.size());
}
@Test
void persistentConnections_areNotAffectedByRegistry() {
// Registry only holds quick connections; a regular DB ID (e.g. 1) is never stored here
assertNull(registry.get(1L));
}
@Test
void multipleEntries_areStoredIndependently() {
Connection a = registry.create("host-a", "ua", 22, 1L);
Connection b = registry.create("host-b", "ub", 2222, 2L);
assertEquals(2, registry.size());
assertNotNull(registry.get(a.getId()));
assertNotNull(registry.get(b.getId()));
}
// ── TTL eviction ─────────────────────────────────────────────────────────
@Test
void evictExpired_removesEntriesOlderThanTtl() throws Exception {
// Set a 0-minute TTL so every entry is immediately "expired"
ReflectionTestUtils.setField(registry, "ttlMinutes", 0L);
registry.create("host", "user", 22, 1L);
assertEquals(1, registry.size());
// Small delay so lastAccessAt < now - ttl (ttl = 0 ms)
Thread.sleep(5);
registry.evictExpired();
assertEquals(0, registry.size());
}
@Test
void evictExpired_keepsRecentEntries() {
// TTL = 30 min — newly created entries should survive
registry.create("host", "user", 22, 1L);
registry.evictExpired();
assertEquals(1, registry.size());
}
@Test
void touch_preventsEviction() throws Exception {
// 1-ms TTL would normally evict immediately
ReflectionTestUtils.setField(registry, "ttlMinutes", 0L);
Connection conn = registry.create("host", "user", 22, 1L);
// Touch resets lastAccessAt to now — entry should survive one sweep cycle
// when the sweep happens within the same millisecond
registry.touch(conn.getId());
// We cannot guarantee sub-millisecond execution, so just verify touch
// doesn't throw and the entry is still accessible right after touch.
assertNotNull(registry.get(conn.getId()));
}
@Test
void touch_onUnknownId_doesNotThrow() {
assertDoesNotThrow(() -> registry.touch(999_999L));
}
// ── id namespace ─────────────────────────────────────────────────────────
@Test
void consecutiveCreates_produceIncreasingIds() {
Connection first = registry.create("h", "u", 22, 1L);
Connection second = registry.create("h", "u", 22, 1L);
assertTrue(second.getId() > first.getId());
}
}