feat: add port forwarding and optimize connection status checks
This commit is contained in:
@@ -3,19 +3,20 @@ package com.sshmanager.service;
|
||||
import com.sshmanager.dto.ConnectionStatusCheckRequest;
|
||||
import com.sshmanager.dto.ConnectionStatusResponseDto;
|
||||
import com.sshmanager.entity.Connection;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyInt;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
@@ -24,49 +25,128 @@ class ConnectionStatusServiceTest {
|
||||
@Mock
|
||||
private ConnectionService connectionService;
|
||||
|
||||
@InjectMocks
|
||||
@Mock
|
||||
private TcpProbe tcpProbe;
|
||||
|
||||
private ConnectionStatusService connectionStatusService;
|
||||
|
||||
@Test
|
||||
void checkStatusesAggregatesOnlineAndOfflineResults() {
|
||||
Connection onlineConnection = new Connection();
|
||||
onlineConnection.setId(1L);
|
||||
onlineConnection.setUserId(99L);
|
||||
onlineConnection.setName("prod");
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
connectionStatusService = new ConnectionStatusService(connectionService, tcpProbe);
|
||||
}
|
||||
|
||||
Connection offlineConnection = new Connection();
|
||||
offlineConnection.setId(2L);
|
||||
offlineConnection.setUserId(99L);
|
||||
offlineConnection.setName("test");
|
||||
@Test
|
||||
void checkStatusesRejectsEmptyConnectionIds() {
|
||||
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
|
||||
IllegalArgumentException error = assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> connectionStatusService.checkStatuses(1L, request)
|
||||
);
|
||||
assertEquals("At least one connection is required", error.getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkStatusesMarksReachableAsOnline() throws Exception {
|
||||
Connection conn = new Connection();
|
||||
conn.setId(1L);
|
||||
conn.setUserId(99L);
|
||||
conn.setName("reachable-host");
|
||||
conn.setHost("10.0.0.1");
|
||||
conn.setPort(22);
|
||||
|
||||
when(connectionService.getConnectionForSsh(1L, 99L)).thenReturn(conn);
|
||||
|
||||
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
|
||||
request.setConnectionIds(Arrays.asList(1L, 2L));
|
||||
request.setConnectionIds(Arrays.asList(1L));
|
||||
|
||||
when(connectionService.getConnectionForSsh(1L, 99L)).thenReturn(onlineConnection);
|
||||
when(connectionService.getConnectionForSsh(2L, 99L)).thenReturn(offlineConnection);
|
||||
doReturn(onlineConnection).when(connectionService).testConnection(eq(onlineConnection), eq(null), eq(null), eq(null));
|
||||
doThrow(new RuntimeException("Connection refused")).when(connectionService).testConnection(eq(offlineConnection), eq(null), eq(null), eq(null));
|
||||
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
|
||||
|
||||
assertEquals(1, response.getTotal());
|
||||
assertEquals(1, response.getOnlineCount());
|
||||
assertEquals(0, response.getOfflineCount());
|
||||
assertEquals("online", response.getResults().get(0).getStatus());
|
||||
assertEquals("TCP port reachable", response.getResults().get(0).getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkStatusesMarksUnreachableAsOffline() throws Exception {
|
||||
Connection conn = new Connection();
|
||||
conn.setId(2L);
|
||||
conn.setUserId(99L);
|
||||
conn.setName("unreachable-host");
|
||||
conn.setHost("10.0.0.2");
|
||||
conn.setPort(22);
|
||||
|
||||
when(connectionService.getConnectionForSsh(2L, 99L)).thenReturn(conn);
|
||||
doAnswer(invocation -> { throw new RuntimeException("Connection refused"); })
|
||||
.when(tcpProbe).checkReachable(conn, 5);
|
||||
|
||||
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
|
||||
request.setConnectionIds(Arrays.asList(2L));
|
||||
|
||||
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
|
||||
|
||||
assertEquals(1, response.getTotal());
|
||||
assertEquals(0, response.getOnlineCount());
|
||||
assertEquals(1, response.getOfflineCount());
|
||||
assertEquals("offline", response.getResults().get(0).getStatus());
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkStatusesAggregatesMixedResults() throws Exception {
|
||||
Connection online = new Connection();
|
||||
online.setId(10L);
|
||||
online.setUserId(99L);
|
||||
online.setName("server-a");
|
||||
online.setHost("10.0.0.1");
|
||||
online.setPort(22);
|
||||
|
||||
Connection offline = new Connection();
|
||||
offline.setId(20L);
|
||||
offline.setUserId(99L);
|
||||
offline.setName("server-b");
|
||||
offline.setHost("10.0.0.2");
|
||||
offline.setPort(22);
|
||||
|
||||
when(connectionService.getConnectionForSsh(10L, 99L)).thenReturn(online);
|
||||
when(connectionService.getConnectionForSsh(20L, 99L)).thenReturn(offline);
|
||||
// Answer with argument matching based on connection ID
|
||||
doAnswer(invocation -> {
|
||||
Connection c = invocation.getArgument(0);
|
||||
if (c.getId() == 20L) throw new RuntimeException("Connection refused");
|
||||
return null;
|
||||
}).when(tcpProbe).checkReachable(any(Connection.class), anyInt());
|
||||
|
||||
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
|
||||
request.setConnectionIds(Arrays.asList(10L, 20L));
|
||||
|
||||
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
|
||||
|
||||
assertEquals(2, response.getTotal());
|
||||
assertEquals(1, response.getOnlineCount());
|
||||
assertEquals(1, response.getOfflineCount());
|
||||
assertEquals("online", response.getResults().get(0).getStatus());
|
||||
assertEquals("offline", response.getResults().get(1).getStatus());
|
||||
assertEquals("prod", response.getResults().get(0).getConnectionName());
|
||||
assertEquals("Connection refused", response.getResults().get(1).getMessage());
|
||||
}
|
||||
|
||||
@Test
|
||||
void checkStatusesRejectsEmptyConnectionIds() {
|
||||
void checkStatusesHandlesAllOfflineGracefully() throws Exception {
|
||||
Connection conn = new Connection();
|
||||
conn.setId(99L);
|
||||
conn.setUserId(99L);
|
||||
conn.setName("fail-host");
|
||||
conn.setHost("10.0.0.1");
|
||||
conn.setPort(22);
|
||||
|
||||
when(connectionService.getConnectionForSsh(99L, 99L)).thenReturn(conn);
|
||||
doAnswer(invocation -> { throw new RuntimeException("timeout"); })
|
||||
.when(tcpProbe).checkReachable(conn, 5);
|
||||
|
||||
ConnectionStatusCheckRequest request = new ConnectionStatusCheckRequest();
|
||||
request.setConnectionIds(Collections.singletonList(99L));
|
||||
|
||||
IllegalArgumentException error = assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> connectionStatusService.checkStatuses(1L, request)
|
||||
);
|
||||
ConnectionStatusResponseDto response = connectionStatusService.checkStatuses(99L, request);
|
||||
|
||||
assertEquals("At least one connection is required", error.getMessage());
|
||||
assertEquals(1, response.getTotal());
|
||||
assertEquals(0, response.getOnlineCount());
|
||||
assertEquals(1, response.getOfflineCount());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
package com.sshmanager.service;
|
||||
|
||||
import com.jcraft.jsch.Session;
|
||||
import com.sshmanager.entity.Connection;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link PortForwardRegistry}.
|
||||
*
|
||||
* <p>The JSch {@link Session} is mocked so no real SSH connections are made.
|
||||
* Reflection is used to inject mock tunnels into the registry.
|
||||
*/
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class PortForwardRegistryTest {
|
||||
|
||||
@Mock
|
||||
private Session mockSession;
|
||||
|
||||
@Mock
|
||||
private SshSessionFactory sshSessionFactory;
|
||||
|
||||
private PortForwardRegistry registry;
|
||||
private Connection connection;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
registry = new PortForwardRegistry(sshSessionFactory);
|
||||
connection = new Connection();
|
||||
connection.setId(1L);
|
||||
connection.setUserId(100L);
|
||||
connection.setName("test-conn");
|
||||
connection.setHost("example.com");
|
||||
connection.setPort(22);
|
||||
connection.setUsername("user");
|
||||
connection.setAuthType(Connection.AuthType.PASSWORD);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private PortForwardRegistry.TunnelEntry addMockTunnel(Long userId, int localPort,
|
||||
String remoteHost, int remotePort)
|
||||
throws Exception {
|
||||
String id = "test-id-" + localPort;
|
||||
PortForwardRegistry.TunnelEntry entry = new PortForwardRegistry.TunnelEntry(
|
||||
id, userId, connection.getId(), connection.getName(),
|
||||
localPort, remoteHost, remotePort, mockSession);
|
||||
|
||||
Map<String, PortForwardRegistry.TunnelEntry> tunnels =
|
||||
(Map<String, PortForwardRegistry.TunnelEntry>) ReflectionTestUtils.getField(registry, "tunnels");
|
||||
if (tunnels != null) {
|
||||
tunnels.put(id, entry);
|
||||
}
|
||||
return entry;
|
||||
}
|
||||
|
||||
// ── validation tests ─────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void create_throwsOnInvalidLocalPort() {
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
registry.create(100L, connection, "pass", null, null, 0, "127.0.0.1", 3306));
|
||||
}
|
||||
|
||||
@Test
|
||||
void create_throwsOnPortAboveRange() {
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
registry.create(100L, connection, "pass", null, null, 70000, "127.0.0.1", 3306));
|
||||
}
|
||||
|
||||
@Test
|
||||
void create_throwsOnInvalidRemotePort() {
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
registry.create(100L, connection, "pass", null, null, 8080, "127.0.0.1", 0));
|
||||
}
|
||||
|
||||
@Test
|
||||
void create_throwsOnBlankRemoteHost() {
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
registry.create(100L, connection, "pass", null, null, 8080, " ", 3306));
|
||||
}
|
||||
|
||||
@Test
|
||||
void create_throwsOnNullRemoteHost() {
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
registry.create(100L, connection, "pass", null, null, 8080, null, 3306));
|
||||
}
|
||||
|
||||
@Test
|
||||
void create_success() throws Exception {
|
||||
when(sshSessionFactory.createSession(any(), any(), any(), any()))
|
||||
.thenReturn(mockSession);
|
||||
when(mockSession.setPortForwardingL(anyInt(), anyString(), anyInt())).thenReturn(8080);
|
||||
|
||||
PortForwardRegistry.TunnelEntry entry = registry.create(100L, connection, "pass", null, null, 8080, "127.0.0.1", 3306);
|
||||
|
||||
assertNotNull(entry);
|
||||
assertEquals(100L, entry.getUserId());
|
||||
assertEquals(8080, entry.getLocalPort());
|
||||
assertEquals("127.0.0.1", entry.getRemoteHost());
|
||||
assertEquals(3306, entry.getRemotePort());
|
||||
assertEquals(PortForwardRegistry.TunnelStatus.RUNNING, entry.getStatus());
|
||||
verify(mockSession).setPortForwardingL(8080, "127.0.0.1", 3306);
|
||||
}
|
||||
|
||||
// ── TunnelEntry model tests ───────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void tunnelEntry_initialStatusIsRunning() throws Exception {
|
||||
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
|
||||
assertEquals(PortForwardRegistry.TunnelStatus.RUNNING, entry.getStatus());
|
||||
}
|
||||
|
||||
@Test
|
||||
void tunnelEntry_fieldsAreCorrectlySet() throws Exception {
|
||||
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "db.internal", 5432);
|
||||
assertEquals(100L, entry.getUserId());
|
||||
assertEquals(8080, entry.getLocalPort());
|
||||
assertEquals("db.internal", entry.getRemoteHost());
|
||||
assertEquals(5432, entry.getRemotePort());
|
||||
assertNotNull(entry.getCreatedAt());
|
||||
assertNotNull(entry.getId());
|
||||
}
|
||||
|
||||
// ── stop tests ────────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void stop_throwsOnUnknownId() {
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
registry.stop("nonexistent-id", 100L));
|
||||
}
|
||||
|
||||
@Test
|
||||
void stop_stopsActiveTunnel() throws Exception {
|
||||
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
|
||||
when(mockSession.isConnected()).thenReturn(true);
|
||||
|
||||
registry.stop(entry.getId(), 100L);
|
||||
|
||||
verify(mockSession).delPortForwardingL(8080);
|
||||
verify(mockSession).disconnect();
|
||||
assertEquals(PortForwardRegistry.TunnelStatus.STOPPED, entry.getStatus());
|
||||
assertTrue(registry.listByUser(100L).isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void stop_throwsOnAccessDenied() throws Exception {
|
||||
PortForwardRegistry.TunnelEntry entry = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
|
||||
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
registry.stop(entry.getId(), 200L)); // wrong user
|
||||
|
||||
assertEquals(PortForwardRegistry.TunnelStatus.RUNNING, entry.getStatus());
|
||||
verify(mockSession, never()).disconnect();
|
||||
}
|
||||
|
||||
// ── listByUser ────────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void listByUser_returnsEmptyWhenNone() {
|
||||
List<PortForwardRegistry.TunnelEntry> list = registry.listByUser(100L);
|
||||
assertNotNull(list);
|
||||
assertTrue(list.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
void listByUser_returnsOnlyUserTunnels() throws Exception {
|
||||
PortForwardRegistry.TunnelEntry e1 = addMockTunnel(100L, 8080, "127.0.0.1", 3306);
|
||||
PortForwardRegistry.TunnelEntry e2 = addMockTunnel(100L, 8081, "127.0.0.1", 5432);
|
||||
PortForwardRegistry.TunnelEntry e3 = addMockTunnel(200L, 8082, "127.0.0.1", 6379);
|
||||
|
||||
List<PortForwardRegistry.TunnelEntry> user100Tunnels = registry.listByUser(100L);
|
||||
assertEquals(2, user100Tunnels.size());
|
||||
assertTrue(user100Tunnels.contains(e1));
|
||||
assertTrue(user100Tunnels.contains(e2));
|
||||
assertFalse(user100Tunnels.contains(e3));
|
||||
|
||||
List<PortForwardRegistry.TunnelEntry> user200Tunnels = registry.listByUser(200L);
|
||||
assertEquals(1, user200Tunnels.size());
|
||||
assertTrue(user200Tunnels.contains(e3));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package com.sshmanager.service;
|
||||
|
||||
import com.sshmanager.entity.Connection;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.test.util.ReflectionTestUtils;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link QuickConnectionRegistry}.
|
||||
*
|
||||
* <p>These tests exercise the core lifecycle (create / get / remove) and the TTL
|
||||
* eviction logic without starting a Spring context.
|
||||
*/
|
||||
class QuickConnectionRegistryTest {
|
||||
|
||||
private QuickConnectionRegistry registry;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
registry = new QuickConnectionRegistry();
|
||||
// Default TTL: 30 minutes (not relevant for most tests — overridden where needed)
|
||||
ReflectionTestUtils.setField(registry, "ttlMinutes", 30L);
|
||||
}
|
||||
|
||||
// ── basic lifecycle ───────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void create_returnsConnectionWithCorrectFields() {
|
||||
Connection conn = registry.create("192.168.1.1", "admin", 22, 1L);
|
||||
|
||||
assertNotNull(conn.getId());
|
||||
assertTrue(conn.getId() >= 10_000_001L, "ID should be in quick-connect namespace");
|
||||
assertEquals("192.168.1.1", conn.getHost());
|
||||
assertEquals("admin", conn.getUsername());
|
||||
assertEquals(22, conn.getPort());
|
||||
assertEquals(1L, conn.getUserId());
|
||||
assertEquals(Connection.AuthType.PASSWORD, conn.getAuthType());
|
||||
}
|
||||
|
||||
@Test
|
||||
void get_returnsConnectionAfterCreate() {
|
||||
Connection conn = registry.create("host", "user", 22, 1L);
|
||||
|
||||
Connection found = registry.get(conn.getId());
|
||||
assertNotNull(found);
|
||||
assertEquals(conn.getId(), found.getId());
|
||||
}
|
||||
|
||||
@Test
|
||||
void get_returnsNullForUnknownId() {
|
||||
assertNull(registry.get(999_999L));
|
||||
}
|
||||
|
||||
@Test
|
||||
void remove_deletesEntry() {
|
||||
Connection conn = registry.create("host", "user", 22, 1L);
|
||||
registry.remove(conn.getId());
|
||||
|
||||
assertNull(registry.get(conn.getId()));
|
||||
assertEquals(0, registry.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
void persistentConnections_areNotAffectedByRegistry() {
|
||||
// Registry only holds quick connections; a regular DB ID (e.g. 1) is never stored here
|
||||
assertNull(registry.get(1L));
|
||||
}
|
||||
|
||||
@Test
|
||||
void multipleEntries_areStoredIndependently() {
|
||||
Connection a = registry.create("host-a", "ua", 22, 1L);
|
||||
Connection b = registry.create("host-b", "ub", 2222, 2L);
|
||||
|
||||
assertEquals(2, registry.size());
|
||||
assertNotNull(registry.get(a.getId()));
|
||||
assertNotNull(registry.get(b.getId()));
|
||||
}
|
||||
|
||||
// ── TTL eviction ─────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void evictExpired_removesEntriesOlderThanTtl() throws Exception {
|
||||
// Set a 0-minute TTL so every entry is immediately "expired"
|
||||
ReflectionTestUtils.setField(registry, "ttlMinutes", 0L);
|
||||
|
||||
registry.create("host", "user", 22, 1L);
|
||||
assertEquals(1, registry.size());
|
||||
|
||||
// Small delay so lastAccessAt < now - ttl (ttl = 0 ms)
|
||||
Thread.sleep(5);
|
||||
registry.evictExpired();
|
||||
|
||||
assertEquals(0, registry.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
void evictExpired_keepsRecentEntries() {
|
||||
// TTL = 30 min — newly created entries should survive
|
||||
registry.create("host", "user", 22, 1L);
|
||||
registry.evictExpired();
|
||||
|
||||
assertEquals(1, registry.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
void touch_preventsEviction() throws Exception {
|
||||
// 1-ms TTL would normally evict immediately
|
||||
ReflectionTestUtils.setField(registry, "ttlMinutes", 0L);
|
||||
|
||||
Connection conn = registry.create("host", "user", 22, 1L);
|
||||
|
||||
// Touch resets lastAccessAt to now — entry should survive one sweep cycle
|
||||
// when the sweep happens within the same millisecond
|
||||
registry.touch(conn.getId());
|
||||
|
||||
// We cannot guarantee sub-millisecond execution, so just verify touch
|
||||
// doesn't throw and the entry is still accessible right after touch.
|
||||
assertNotNull(registry.get(conn.getId()));
|
||||
}
|
||||
|
||||
@Test
|
||||
void touch_onUnknownId_doesNotThrow() {
|
||||
assertDoesNotThrow(() -> registry.touch(999_999L));
|
||||
}
|
||||
|
||||
// ── id namespace ─────────────────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
void consecutiveCreates_produceIncreasingIds() {
|
||||
Connection first = registry.create("h", "u", 22, 1L);
|
||||
Connection second = registry.create("h", "u", 22, 1L);
|
||||
|
||||
assertTrue(second.getId() > first.getId());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user