package com.sshmanager.controller;

import com.sshmanager.entity.Connection;
import com.sshmanager.entity.User;
import com.sshmanager.repository.ConnectionRepository;
import com.sshmanager.repository.UserRepository;
import com.sshmanager.service.ConnectionService;
import com.sshmanager.service.SshService;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Bridges a browser WebSocket to an interactive SSH shell.
 *
 * <p>For each WebSocket connection the handler looks up the {@link Connection} named by the
 * {@code connectionId} session attribute, verifies that it belongs to the user named by the
 * {@code username} session attribute, opens a shell through {@link SshService}, and then
 * relays text in both directions: shell output is pushed to the client by a task submitted to
 * the injected executor, and incoming text frames are written to the shell.
 *
 * <p>Open shells are tracked per WebSocket session id in concurrent maps.
 */
@Component
public class TerminalWebSocketHandler extends TextWebSocketHandler {

    private final ConnectionRepository connectionRepository;
    private final UserRepository userRepository;
    private final ConnectionService connectionService;
    private final SshService sshService;
    private final ExecutorService executor;

    /** Number of shells currently registered in {@link #sessions}. */
    private final AtomicInteger sessionCount = new AtomicInteger(0);

    /** Open SSH shells keyed by WebSocket session id. */
    private final Map<String, SshService.SshSession> sessions = new ConcurrentHashMap<>();

    /** Epoch-millis timestamp of the last client activity, keyed by WebSocket session id. */
    private final Map<String, Long> lastActivity = new ConcurrentHashMap<>();

    public TerminalWebSocketHandler(ConnectionRepository connectionRepository,
                                    UserRepository userRepository,
                                    ConnectionService connectionService,
                                    SshService sshService,
                                    @Qualifier("terminalWebSocketExecutor") ExecutorService executor) {
        this.connectionRepository = connectionRepository;
        this.userRepository = userRepository;
        this.connectionService = connectionService;
        this.sshService = sshService;
        this.executor = executor;
    }

    /**
     * Validates the handshake attributes, opens the SSH shell and starts relaying its output.
     *
     * <p>The WebSocket is closed with {@link CloseStatus#BAD_DATA} when the attributes are
     * missing, the user is unknown, or the connection does not exist or is owned by another
     * user. If the shell cannot be started, the client receives an {@code [SSH Error: ...]}
     * message and the WebSocket is closed with {@link CloseStatus#SERVER_ERROR}.
     *
     * @param webSocketSession the newly established WebSocket session
     * @throws Exception if credential decryption fails or the WebSocket cannot be written to
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession webSocketSession) throws Exception {
        Long connectionId = (Long) webSocketSession.getAttributes().get("connectionId");
        String username = (String) webSocketSession.getAttributes().get("username");
        if (connectionId == null || username == null) {
            webSocketSession.close(CloseStatus.BAD_DATA);
            return;
        }

        User user = userRepository.findByUsername(username).orElse(null);
        if (user == null) {
            webSocketSession.close(CloseStatus.BAD_DATA);
            return;
        }

        // Ownership check: a user may only open terminals for their own connections.
        Connection conn = connectionRepository.findById(connectionId).orElse(null);
        if (conn == null || !conn.getUserId().equals(user.getId())) {
            webSocketSession.close(CloseStatus.BAD_DATA);
            return;
        }

        String password = connectionService.getDecryptedPassword(conn);
        String privateKey = connectionService.getDecryptedPrivateKey(conn);
        String passphrase = connectionService.getDecryptedPassphrase(conn);

        String sessionId = webSocketSession.getId();
        try {
            SshService.SshSession sshSession =
                    sshService.createShellSession(conn, password, privateKey, passphrase);
            sessions.put(sessionId, sshSession);
            sessionCount.incrementAndGet();
            lastActivity.put(sessionId, System.currentTimeMillis());
            executor.submit(() -> pumpShellOutput(webSocketSession, sshSession));
        } catch (Exception e) {
            // The shell may already be registered (e.g. the executor rejected the task),
            // so release it rather than leaving it open with nobody reading from it.
            releaseSession(sessionId);
            if (webSocketSession.isOpen()) {
                webSocketSession.sendMessage(
                        new TextMessage("\r\n[SSH Error: " + e.getMessage() + "]\r\n"));
                // Without a shell the socket is useless; close it so the client can react.
                webSocketSession.close(CloseStatus.SERVER_ERROR);
            }
        }
    }

    /**
     * Forwards a text frame from the client to the shell and records the activity time.
     * Frames are silently dropped when no connected shell is registered for the session.
     *
     * @param webSocketSession the session the frame arrived on
     * @param message          the text typed by the client
     * @throws Exception if writing to the shell fails
     */
    @Override
    protected void handleTextMessage(WebSocketSession webSocketSession, TextMessage message)
            throws Exception {
        String sessionId = webSocketSession.getId();
        SshService.SshSession sshSession = sessions.get(sessionId);
        if (sshSession == null || !sshSession.isConnected()) {
            return;
        }
        lastActivity.put(sessionId, System.currentTimeMillis());
        // NOTE(review): getInputStream() is written to here, so it appears to be the
        // shell's stdin as seen from the shell side - confirm against SshService.
        sshSession.getInputStream().write(message.asBytes());
        sshSession.getInputStream().flush();
    }

    /**
     * Disconnects and unregisters the shell that belongs to the closed WebSocket, if any.
     *
     * @param webSocketSession the session that was closed
     * @param status           the close status reported for the session
     */
    @Override
    public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus status)
            throws Exception {
        releaseSession(webSocketSession.getId());
    }

    /**
     * Relays shell output to the client until either side goes away, then notifies the client,
     * closes the WebSocket if it is still open, and releases the shell.
     */
    private void pumpShellOutput(WebSocketSession webSocketSession,
                                 SshService.SshSession sshSession) {
        try {
            InputStream shellOutput = sshSession.getOutputStream();
            // Decode through a Reader so a multi-byte UTF-8 sequence that straddles two reads
            // is carried over instead of being turned into replacement characters.
            // The reader is deliberately not closed here: the underlying stream belongs to
            // the SSH session and is torn down by disconnect().
            Reader decoded = new InputStreamReader(shellOutput, StandardCharsets.UTF_8);
            char[] chunk = new char[4096];
            int count;
            while (webSocketSession.isOpen()
                    && sshSession.isConnected()
                    && (count = decoded.read(chunk)) >= 0) {
                webSocketSession.sendMessage(new TextMessage(new String(chunk, 0, count)));
            }
        } catch (Exception ignored) {
            // Read/send failures mean one side is gone; the finally block handles teardown.
        } finally {
            if (webSocketSession.isOpen()) {
                try {
                    webSocketSession.sendMessage(new TextMessage("\r\n[Connection closed]\r\n"));
                    webSocketSession.close(CloseStatus.NORMAL);
                } catch (IOException ignored) {
                    // Best effort: the client may have disconnected in the meantime.
                }
            }
            releaseSession(webSocketSession.getId());
        }
    }

    /**
     * Removes the bookkeeping for {@code sessionId} and disconnects its shell.
     * Safe to call more than once: only the call that actually removes the map entry
     * disconnects the shell and decrements the counter.
     */
    private void releaseSession(String sessionId) {
        lastActivity.remove(sessionId);
        SshService.SshSession sshSession = sessions.remove(sessionId);
        if (sshSession != null) {
            sshSession.disconnect();
            sessionCount.decrementAndGet();
        }
    }
}