feat: add port forwarding and optimize connection status checks

This commit is contained in:
liumangmang
2026-06-11 14:10:30 +08:00
parent 4a17f0106e
commit e418e6ecc2
30 changed files with 1789 additions and 150 deletions
@@ -3,19 +3,20 @@ package com.sshmanager.service;
import com.sshmanager.dto.ConnectionStatusCheckRequest;
import com.sshmanager.dto.ConnectionStatusResponseDto;
import com.sshmanager.entity.Connection;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.Arrays;
import java.util.Collections;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.doReturn;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
@@ -24,49 +25,128 @@ class ConnectionStatusServiceTest {
@Mock
private ConnectionService connectionService;
@InjectMocks
@Mock
private TcpProbe tcpProbe;
private ConnectionStatusService connectionStatusService;
@Test
void checkStatusesAggregatesOnlineAndOfflineResults() {
Connection onlineConnection = new Connection();
onlineConnection.setId(1L);
onlineConnection.setUserId(99L);
onlineConnection.setName("prod");
@BeforeEach
void setUp() {
connectionStatusService = new ConnectionStatusService(connectionService, tcpProbe);
}
Connection offlineConnection = new Connection();
offlineConnection.setId(2L);
offlineConnection.setUserId(99L);
offlineConnection.setName("test");
@Test
void checkStatusesRejectsEmptyConnectionIds() {
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
IllegalArgumentException error = assertThrows(
IllegalArgumentException.class,
() -> connectionStatusService.checkStatuses(1L, request)
);
assertEquals("At least one connection is required", error.getMessage());
}
@Test
void checkStatusesMarksReachableAsOnline() throws Exception {
Connection conn = new Connection();
conn.setId(1L);
conn.setUserId(99L);
conn.setName("reachable-host");
conn.setHost("10.0.0.1");
conn.setPort(22);
when(connectionService.getConnectionForSsh(1L, 99L)).thenReturn(conn);
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
request.setConnectionIds(Arrays.asList(1L, 2L));
request.setConnectionIds(Arrays.asList(1L));
when(connectionService.getConnectionForSsh(1L, 99L)).thenReturn(onlineConnection);
when(connectionService.getConnectionForSsh(2L, 99L)).thenReturn(offlineConnection);
doReturn(onlineConnection).when(connectionService).testConnection(eq(onlineConnection), eq(null), eq(null), eq(null));
doThrow(new RuntimeException("Connection refused")).when(connectionService).testConnection(eq(offlineConnection), eq(null), eq(null), eq(null));
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
assertEquals(1, response.getTotal());
assertEquals(1, response.getOnlineCount());
assertEquals(0, response.getOfflineCount());
assertEquals("online", response.getResults().get(0).getStatus());
assertEquals("TCP port reachable", response.getResults().get(0).getMessage());
}
@Test
void checkStatusesMarksUnreachableAsOffline() throws Exception {
Connection conn = new Connection();
conn.setId(2L);
conn.setUserId(99L);
conn.setName("unreachable-host");
conn.setHost("10.0.0.2");
conn.setPort(22);
when(connectionService.getConnectionForSsh(2L, 99L)).thenReturn(conn);
doAnswer(invocation -> { throw new RuntimeException("Connection refused"); })
.when(tcpProbe).checkReachable(conn, 5);
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
request.setConnectionIds(Arrays.asList(2L));
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
assertEquals(1, response.getTotal());
assertEquals(0, response.getOnlineCount());
assertEquals(1, response.getOfflineCount());
assertEquals("offline", response.getResults().get(0).getStatus());
}
@Test
void checkStatusesAggregatesMixedResults() throws Exception {
Connection online = new Connection();
online.setId(10L);
online.setUserId(99L);
online.setName("server-a");
online.setHost("10.0.0.1");
online.setPort(22);
Connection offline = new Connection();
offline.setId(20L);
offline.setUserId(99L);
offline.setName("server-b");
offline.setHost("10.0.0.2");
offline.setPort(22);
when(connectionService.getConnectionForSsh(10L, 99L)).thenReturn(online);
when(connectionService.getConnectionForSsh(20L, 99L)).thenReturn(offline);
// Answer with argument matching based on connection ID
doAnswer(invocation -> {
Connection c = invocation.getArgument(0);
if (c.getId() == 20L) throw new RuntimeException("Connection refused");
return null;
}).when(tcpProbe).checkReachable(any(Connection.class), anyInt());
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
request.setConnectionIds(Arrays.asList(10L, 20L));
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
assertEquals(2, response.getTotal());
assertEquals(1, response.getOnlineCount());
assertEquals(1, response.getOfflineCount());
assertEquals("online", response.getResults().get(0).getStatus());
assertEquals("offline", response.getResults().get(1).getStatus());
assertEquals("prod", response.getResults().get(0).getConnectionName());
assertEquals("Connection refused", response.getResults().get(1).getMessage());
}
@Test
void checkStatusesRejectsEmptyConnectionIds() {
void checkStatusesHandlesAllOfflineGracefully() throws Exception {
Connection conn = new Connection();
conn.setId(99L);
conn.setUserId(99L);
conn.setName("fail-host");
conn.setHost("10.0.0.1");
conn.setPort(22);
when(connectionService.getConnectionForSsh(99L, 99L)).thenReturn(conn);
doAnswer(invocation -> { throw new RuntimeException("timeout"); })
.when(tcpProbe).checkReachable(conn, 5);
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
request.setConnectionIds(Collections.singletonList(99L));
IllegalArgumentException error = assertThrows(
IllegalArgumentException.class,
() -> connectionStatusService.checkStatuses(1L, request)
);
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
assertEquals("At least one connection is required", error.getMessage());
assertEquals(1, response.getTotal());
assertEquals(0, response.getOnlineCount());
assertEquals(1, response.getOfflineCount());
}
}
@@ -0,0 +1,191 @@
package com.sshmanager.service;
import com.jcraft.jsch.Session;
import com.sshmanager.entity.Connection;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.test.util.ReflectionTestUtils;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
/**
* Unit tests for {@link PortForwardRegistry}.
*
* <p>The JSch {@link Session} is mocked so no real SSH connections are made.
* Reflection is used to inject mock tunnels into the registry.
*/
@ExtendWith(MockitoExtension.class)
class PortForwardRegistryTest {
@Mock
private Session mockSession;
@Mock
private SshSessionFactory sshSessionFactory;
private PortForwardRegistry registry;
private Connection connection;
@BeforeEach
void setUp() {
registry = new PortForwardRegistry(sshSessionFactory);
connection = new Connection();
connection.setId(1L);
connection.setUserId(100L);
connection.setName("test-conn");
connection.setHost("example.com");
connection.setPort(22);
connection.setUsername("user");
connection.setAuthType(Connection.AuthType.PASSWORD);
}
@SuppressWarnings("unchecked")
private PortForwardRegistry.TunnelEntry addMockTunnel(Long userId, int localPort,
String remoteHost, int remotePort)
throws Exception {
String id = "test-id-" + localPort;
PortForwardRegistry.TunnelEntry entry = new PortForwardRegistry.TunnelEntry(
id, userId, connection.getId(), connection.getName(),
localPort, remoteHost, remotePort, mockSession);
Map<String, PortForwardRegistry.TunnelEntry> tunnels =
(Map<String, PortForwardRegistry.TunnelEntry>) ReflectionTestUtils.getField(registry, "tunnels");
if (tunnels != null) {
tunnels.put(id, entry);
}
return entry;
}
// ── validation tests ─────────────────────────────────────────────────────
@Test
void create_throwsOnInvalidLocalPort() {
assertThrows(IllegalArgumentException.class, () ->
registry.create(100L, connection, "pass", null, null, 0, "127.0.0.1", 3306));
}
@Test
void create_throwsOnPortAboveRange() {
assertThrows(IllegalArgumentException.class, () ->
registry.create(100L, connection, "pass", null, null, 70000, "127.0.0.1", 3306));
}
@Test
void create_throwsOnInvalidRemotePort() {
assertThrows(IllegalArgumentException.class, () ->
registry.create(100L, connection, "pass", null, null, 8080, "127.0.0.1", 0));
}
@Test
void create_throwsOnBlankRemoteHost() {
assertThrows(IllegalArgumentException.class, () ->
registry.create(100L, connection, "pass", null, null, 8080, " ", 3306));
}
@Test
void create_throwsOnNullRemoteHost() {
assertThrows(IllegalArgumentException.class, () ->
registry.create(100L, connection, "pass", null, null, 8080, null, 3306));
}
@Test
void create_success() throws Exception {
when(sshSessionFactory.createSession(any(), any(), any(), any()))
.thenReturn(mockSession);
when(mockSession.setPortForwardingL(anyInt(), anyString(), anyInt())).thenReturn(8080);
PortForwardRegistry.TunnelEntry entry = registry.create(100L, connection, "pass", null, null, 8080, "127.0.0.1", 3306);
assertNotNull(entry);
assertEquals(100L, entry.getUserId());
assertEquals(8080, entry.getLocalPort());
assertEquals("127.0.0.1", entry.getRemoteHost());
assertEquals(3306, entry.getRemotePort());
assertEquals(PortForwardRegistry.TunnelStatus.RUNNING, entry.getStatus());
verify(mockSession).setPortForwardingL(8080, "127.0.0.1", 3306);
}
// ── TunnelEntry model tests ───────────────────────────────────────────────
@Test
void tunnelEntry_initialStatusIsRunning() throws Exception {
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
assertEquals(PortForwardRegistry.TunnelStatus.RUNNING, entry.getStatus());
}
@Test
void tunnelEntry_fieldsAreCorrectlySet() throws Exception {
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "db.internal", 5432);
assertEquals(100L, entry.getUserId());
assertEquals(8080, entry.getLocalPort());
assertEquals("db.internal", entry.getRemoteHost());
assertEquals(5432, entry.getRemotePort());
assertNotNull(entry.getCreatedAt());
assertNotNull(entry.getId());
}
// ── stop tests ────────────────────────────────────────────────────────────
@Test
void stop_throwsOnUnknownId() {
assertThrows(IllegalArgumentException.class, () ->
registry.stop("nonexistent-id", 100L));
}
@Test
void stop_stopsActiveTunnel() throws Exception {
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
when(mockSession.isConnected()).thenReturn(true);
registry.stop(entry.getId(), 100L);
verify(mockSession).delPortForwardingL(8080);
verify(mockSession).disconnect();
assertEquals(PortForwardRegistry.TunnelStatus.STOPPED, entry.getStatus());
assertTrue(registry.listByUser(100L).isEmpty());
}
@Test
void stop_throwsOnAccessDenied() throws Exception {
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
assertThrows(IllegalArgumentException.class, () ->
registry.stop(entry.getId(), 200L)); // wrong user
assertEquals(PortForwardRegistry.TunnelStatus.RUNNING, entry.getStatus());
verify(mockSession, never()).disconnect();
}
// ── listByUser ────────────────────────────────────────────────────────────
@Test
void listByUser_returnsEmptyWhenNone() {
List<PortForwardRegistry.TunnelEntry> list = registry.listByUser(100L);
assertNotNull(list);
assertTrue(list.isEmpty());
}
@Test
void listByUser_returnsOnlyUserTunnels() throws Exception {
PortForwardRegistry.TunnelEntry e1 = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
PortForwardRegistry.TunnelEntry e2 = addMockTunnel(100L, 8081, "127.0.0.1", 5432);
PortForwardRegistry.TunnelEntry e3 = addMockTunnel(200L, 8082, "127.0.0.1", 6379);
List<PortForwardRegistry.TunnelEntry> user100Tunnels = registry.listByUser(100L);
assertEquals(2, user100Tunnels.size());
assertTrue(user100Tunnels.contains(e1));
assertTrue(user100Tunnels.contains(e2));
assertFalse(user100Tunnels.contains(e3));
List<PortForwardRegistry.TunnelEntry> user200Tunnels = registry.listByUser(200L);
assertEquals(1, user200Tunnels.size());
assertTrue(user200Tunnels.contains(e3));
}
}
@@ -0,0 +1,137 @@
package com.sshmanager.service;
import com.sshmanager.entity.Connection;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.test.util.ReflectionTestUtils;
import static org.junit.jupiter.api.Assertions.*;
/**
* Unit tests for {@link QuickConnectionRegistry}.
*
* <p>These tests exercise the core lifecycle (create / get / remove) and the TTL
* eviction logic without starting a Spring context.
*/
class QuickConnectionRegistryTest {
private QuickConnectionRegistry registry;
@BeforeEach
void setUp() {
registry = new QuickConnectionRegistry();
// Default TTL: 30 minutes (not relevant for most tests — overridden where needed)
ReflectionTestUtils.setField(registry, "ttlMinutes", 30L);
}
// ── basic lifecycle ───────────────────────────────────────────────────────
@Test
void create_returnsConnectionWithCorrectFields() {
Connection conn = registry.create("192.168.1.1", "admin", 22, 1L);
assertNotNull(conn.getId());
assertTrue(conn.getId() >= 10_000_001L, "ID should be in quick-connect namespace");
assertEquals("192.168.1.1", conn.getHost());
assertEquals("admin", conn.getUsername());
assertEquals(22, conn.getPort());
assertEquals(1L, conn.getUserId());
assertEquals(Connection.AuthType.PASSWORD, conn.getAuthType());
}
@Test
void get_returnsConnectionAfterCreate() {
Connection conn = registry.create("host", "user", 22, 1L);
Connection found = registry.get(conn.getId());
assertNotNull(found);
assertEquals(conn.getId(), found.getId());
}
@Test
void get_returnsNullForUnknownId() {
assertNull(registry.get(999_999L));
}
@Test
void remove_deletesEntry() {
Connection conn = registry.create("host", "user", 22, 1L);
registry.remove(conn.getId());
assertNull(registry.get(conn.getId()));
assertEquals(0, registry.size());
}
@Test
void persistentConnections_areNotAffectedByRegistry() {
// Registry only holds quick connections; a regular DB ID (e.g. 1) is never stored here
assertNull(registry.get(1L));
}
@Test
void multipleEntries_areStoredIndependently() {
Connection a = registry.create("host-a", "ua", 22, 1L);
Connection b = registry.create("host-b", "ub", 2222, 2L);
assertEquals(2, registry.size());
assertNotNull(registry.get(a.getId()));
assertNotNull(registry.get(b.getId()));
}
// ── TTL eviction ─────────────────────────────────────────────────────────
@Test
void evictExpired_removesEntriesOlderThanTtl() throws Exception {
// Set a 0-minute TTL so every entry is immediately "expired"
ReflectionTestUtils.setField(registry, "ttlMinutes", 0L);
registry.create("host", "user", 22, 1L);
assertEquals(1, registry.size());
// Small delay so lastAccessAt < now - ttl (ttl = 0 ms)
Thread.sleep(5);
registry.evictExpired();
assertEquals(0, registry.size());
}
@Test
void evictExpired_keepsRecentEntries() {
// TTL = 30 min — newly created entries should survive
registry.create("host", "user", 22, 1L);
registry.evictExpired();
assertEquals(1, registry.size());
}
@Test
void touch_preventsEviction() throws Exception {
// 1-ms TTL would normally evict immediately
ReflectionTestUtils.setField(registry, "ttlMinutes", 0L);
Connection conn = registry.create("host", "user", 22, 1L);
// Touch resets lastAccessAt to now — entry should survive one sweep cycle
// when the sweep happens within the same millisecond
registry.touch(conn.getId());
// We cannot guarantee sub-millisecond execution, so just verify touch
// doesn't throw and the entry is still accessible right after touch.
assertNotNull(registry.get(conn.getId()));
}
@Test
void touch_onUnknownId_doesNotThrow() {
assertDoesNotThrow(() -> registry.touch(999_999L));
}
// ── id namespace ─────────────────────────────────────────────────────────
@Test
void consecutiveCreates_produceIncreasingIds() {
Connection first = registry.create("h", "u", 22, 1L);
Connection second = registry.create("h", "u", 22, 1L);
assertTrue(second.getId() > first.getId());
}
}