반응형

기본 예제

build.gradle.kts

dependencies {
    // ...
    // Spring Boot WebSocket starter: needed by the WebSocketConfigurer / TextWebSocketHandler /
    // handshake-interceptor classes in this example.
    implementation("org.springframework.boot:spring-boot-starter-websocket")
}

DemoHandshakeInterceptor

/**
 * Handshake interceptor that requires `roomId` and `clientId` query parameters.
 *
 * Both values are percent-decoded (UTF-8) and copied into the WebSocket session attributes under the
 * keys `"roomId"` and `"clientId"`. The handshake is rejected (returns `false`) when either parameter
 * is missing, blank, or not validly percent-encoded.
 */
@Component
class DemoHandshakeInterceptor : HttpSessionHandshakeInterceptor() {
    private val log = LoggerFactory.getLogger(this::class.java)

    override fun beforeHandshake(request: ServerHttpRequest, response: ServerHttpResponse, wsHandler: WebSocketHandler, attributes: MutableMap<String, Any>): Boolean {
        // UriComponentsBuilder.fromUri keeps the raw query string, so these values are still
        // percent-encoded and must be decoded before being used as ids.
        val params = UriComponentsBuilder.fromUri(request.uri).build().queryParams
        val roomId = params.getFirst("roomId")?.let { decodeQueryValue(it) }
        val clientId = params.getFirst("clientId")?.let { decodeQueryValue(it) }

        log.info("[Handshake] roomId: {}, clientId: {}", roomId, clientId)

        if (roomId.isNullOrBlank() || clientId.isNullOrBlank()) {
            return false
        }

        attributes["roomId"] = roomId
        attributes["clientId"] = clientId
        return true
    }

    /**
     * Percent-decodes a raw query value as UTF-8, treating `+` as a space (the same convention as
     * `ServletRequest.getParameter`). Returns `null` for malformed escapes such as `%zz`.
     */
    private fun decodeQueryValue(raw: String): String? =
        try {
            java.net.URLDecoder.decode(raw, Charsets.UTF_8)
        } catch (e: IllegalArgumentException) {
            null
        }
}

DemoWebSocketHandler

/**
 * Broadcasts every text message to all sessions in the sender's room, including the sender.
 *
 * Sessions are grouped by the `"roomId"` session attribute, which the handshake interceptor stores.
 * A room entry is created on the first connection and removed when its last session closes.
 */
@Component
class DemoWebSocketHandler : TextWebSocketHandler() {
    private val log = LoggerFactory.getLogger(this::class.java)

    /** roomId -> sessions currently connected to that room. */
    private val clients = ConcurrentHashMap<String, CopyOnWriteArrayList<WebSocketSession>>()

    override fun handleTextMessage(session: WebSocketSession, message: TextMessage) {
        val roomId = session.roomId()
        val clientId = session.clientId()

        // Each recipient is handled independently so one closed/broken peer cannot stop the broadcast.
        clients[roomId]?.forEach { target -> sendTo(target, message) }

        log.info("[send] roomId: {}, clientId: {}, message: {}", roomId, clientId, message.payload)
    }

    override fun afterConnectionEstablished(session: WebSocketSession) {
        val roomId = session.roomId()
        val clientId = session.clientId()
        // compute() creates-or-appends atomically, so a concurrent close that empties and removes
        // the room entry cannot leave this session in a list that is no longer in the map.
        val sessions = clients.compute(roomId) { _, existing ->
            (existing ?: CopyOnWriteArrayList<WebSocketSession>()).apply { add(session) }
        }
        log.info("[connected] roomId: {}, clientId: {}, sessions: {}", roomId, clientId, sessions?.size)
    }

    override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) {
        val roomId = session.roomId()
        val clientId = session.clientId()
        // Returning null from computeIfPresent removes the entry, so rooms everyone has left
        // do not accumulate in the map.
        val sessions = clients.computeIfPresent(roomId) { _, existing ->
            existing.remove(session)
            if (existing.isEmpty()) null else existing
        }
        log.info("[disconnected] roomId: {}, clientId: {}, sessions: {}", roomId, clientId, sessions?.size ?: 0)
    }

    /**
     * Sends [message] to [target], skipping sessions that are no longer open.
     *
     * `WebSocketSession.sendMessage` must not be called concurrently on one session, and several
     * inbound messages may be broadcast at the same time. Sends are therefore serialized per target.
     */
    private fun sendTo(target: WebSocketSession, message: TextMessage) {
        if (!target.isOpen) return
        try {
            synchronized(target) { target.sendMessage(message) }
        } catch (e: Exception) {
            // Deliberately best-effort per recipient: a peer that went away mid-broadcast must not
            // prevent delivery to the others (the exact exception type depends on the container).
            log.warn("[send failed] roomId: {}, clientId: {}", target.roomId(), target.clientId(), e)
        }
    }

    /** Room id stored in the session attributes during the handshake. */
    private fun WebSocketSession.roomId(): String {
        return this.attributes["roomId"] as String
    }

    /** Client id stored in the session attributes during the handshake. */
    private fun WebSocketSession.clientId(): String {
        return this.attributes["clientId"] as String
    }
}

WebSocketConfig

/**
 * Registers the demo WebSocket endpoint.
 *
 * The endpoint is `/ws/demo`. It is served by [DemoWebSocketHandler], runs [DemoHandshakeInterceptor]
 * before each handshake, and accepts handshakes from any origin (`"*"`).
 */
@Configuration
@EnableWebSocket
class WebSocketConfig(
    private val demoWebSocketHandler: DemoWebSocketHandler,
    private val demoHandshakeInterceptor: DemoHandshakeInterceptor,
) : WebSocketConfigurer {
    override fun registerWebSocketHandlers(registry: WebSocketHandlerRegistry) {
        registry.addHandler(demoWebSocketHandler, "/ws/demo").apply {
            addInterceptors(demoHandshakeInterceptor)
            setAllowedOrigins("*")
        }
    }
}

index.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Title</title>
</head>
<body>
<ul id="messages"></ul>
<button id="send">send</button>
<script>
    const roomId = "1";
    const clientId = Math.random().toString(36).substring(7);
    // Percent-encode the query values so ids containing reserved characters (&, =, +, space, ...)
    // reach the handshake interceptor intact.
    const demoWebSocket = new WebSocket(
        `ws://localhost:8080/ws/demo?roomId=${encodeURIComponent(roomId)}&clientId=${encodeURIComponent(clientId)}`
    );

    demoWebSocket.onopen = () => {
        console.log('open');
    };

    demoWebSocket.onclose = () => {
        console.log('close');
    };

    demoWebSocket.onerror = (error) => {
        console.error('error', error);
    };

    // Render each broadcast message as a list item (textContent, so payloads are not parsed as HTML).
    demoWebSocket.onmessage = (event) => {
        const li = document.createElement("li");
        li.textContent = event.data;
        document.querySelector("#messages").appendChild(li);
    };

    document.querySelector('#send').addEventListener('click', () => {
        // send() throws while the socket is still CONNECTING and silently discards data once it is
        // closing/closed, so only send on an open connection.
        if (demoWebSocket.readyState !== WebSocket.OPEN) {
            console.warn('socket is not open');
            return;
        }
        demoWebSocket.send(`Hello World - ${Date.now()}`);
    });
</script>
</body>
</html>

시큐리티 예제

build.gradle.kts

dependencies {
    // ...
    // Spring Boot Security starter: needed by the OncePerRequestFilter / SecurityFilterChain
    // classes in this example.
    implementation("org.springframework.boot:spring-boot-starter-security")
}

UserAuthenticationFilter

/**
 * Authenticates requests that carry an `accessToken` request parameter.
 *
 * The token must be Base64 of `"<roomId>:<clientId>"`, with exactly one `:` and both parts non-blank.
 * On success, a [User] principal is installed in a fresh `SecurityContext`.
 * A malformed token is logged and the request continues unauthenticated. Whether it is then allowed
 * is decided by the authorization rules of the security filter chain.
 */
@Component
class UserAuthenticationFilter : OncePerRequestFilter() {
    private val log = LoggerFactory.getLogger(this::class.java)

    override fun doFilterInternal(request: HttpServletRequest, response: HttpServletResponse, filterChain: FilterChain) {
        val accessToken = request.getParameter("accessToken")

        if (accessToken != null) {
            try {
                val user = parseUser(accessToken)

                // TODO verify that roomId and clientId are valid

                val authentication = object : AbstractAuthenticationToken(listOf()) {
                    override fun getPrincipal() = user
                    override fun getCredentials() = accessToken
                    override fun isAuthenticated() = true
                }

                // Install a new context rather than mutating the one returned by getContext().
                val context = SecurityContextHolder.createEmptyContext()
                context.authentication = authentication
                SecurityContextHolder.setContext(context)
            } catch (e: IllegalArgumentException) {
                // Thrown both for invalid Base64 and for a token that does not decode to "roomId:clientId".
                log.warn("token error", e)
            }
        }

        filterChain.doFilter(request, response)
    }

    /**
     * Decodes [accessToken] into a [User].
     *
     * @throws IllegalArgumentException if the token is not valid Base64, or if it does not decode to
     *   exactly two non-blank `:`-separated parts.
     */
    private fun parseUser(accessToken: String): User {
        val decoded = String(Base64.getDecoder().decode(accessToken), Charsets.UTF_8)
        val parts = decoded.split(":")
        require(parts.size == 2 && parts.none { it.isBlank() }) {
            "access token must decode to \"<roomId>:<clientId>\""
        }
        return User(roomId = parts[0], clientId = parts[1])
    }

    /** Principal carried by the authentication: the room and client identified by the token. */
    data class User(
        val roomId: String,
        val clientId: String,
    )
}

DemoHandshakeInterceptor

/**
 * Handshake interceptor that takes room and client ids from the current Spring Security authentication.
 *
 * The principal must be the `UserAuthenticationFilter.User` created from the access token. Its
 * `roomId`/`clientId` are copied into the WebSocket session attributes.
 * The handshake is rejected (returns `false`) when there is no authenticated `User` principal.
 */
@Component
class DemoHandshakeInterceptor : HttpSessionHandshakeInterceptor() {
    private val log = LoggerFactory.getLogger(this::class.java)

    override fun beforeHandshake(request: ServerHttpRequest, response: ServerHttpResponse, wsHandler: WebSocketHandler, attributes: MutableMap<String, Any>): Boolean {
        val authentication = SecurityContextHolder.getContext().authentication

        // A missing authentication, or one whose principal is not our User (e.g. an anonymous token,
        // which also reports isAuthenticated == true), is rejected instead of failing with an
        // NPE / ClassCastException.
        val user = authentication
            ?.takeIf { it.isAuthenticated }
            ?.principal as? UserAuthenticationFilter.User

        if (user == null) {
            log.warn("authentication failed")
            return false
        }

        attributes["roomId"] = user.roomId
        attributes["clientId"] = user.clientId
        return true
    }
}

SecurityConfig

/**
 * Security filter chain for the WebSocket demo.
 *
 * - `/ws/demo` requires authentication; every other request is permitted.
 * - [UserAuthenticationFilter] runs right after `LogoutFilter`.
 * - CSRF protection is disabled.
 */
@EnableWebSecurity
@Configuration
class SecurityConfig(
    private val userAuthenticationFilter: UserAuthenticationFilter,
) {
    @Bean
    fun filterChain(http: HttpSecurity): SecurityFilterChain {
        http.authorizeHttpRequests { requests ->
            requests
                .requestMatchers("/ws/demo").authenticated()
                .anyRequest().permitAll()
        }
        http.addFilterAfter(userAuthenticationFilter, LogoutFilter::class.java)
        http.csrf { csrf -> csrf.disable() }
        return http.build()
    }
}

index.html

<script>
    const roomId = "1";
    const clientId = Math.random().toString(36).substring(7);
    const accessToken = btoa(`${roomId}:${clientId}`);
    // Base64 output may contain '+', '/' and '=', which are not safe verbatim in a query string
    // ('+' would be decoded to a space by request.getParameter), so percent-encode the token.
    const demoWebSocket = new WebSocket(`ws://localhost:8080/ws/demo?accessToken=${encodeURIComponent(accessToken)}`);
    // ...
</script>

 

반응형

+ Recent posts