package com.rax.vital.config; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler; @Configuration @EnableWebSocketMessageBroker public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { @Autowired private OAuth2AuthorizationService authorizationService; @Override public void registerStompEndpoints(StompEndpointRegistry registry) { registry.addEndpoint("/rax/chat", "/rax/ai-medicine", "/rax/doctor-medicine", "/rax/vital-signs", "/rax/SurgeryData") .setAllowedOrigins("*"); // 错误处理 registry.setErrorHandler(new StompSubProtocolErrorHandler() { @Override public Message handleClientMessageProcessingError(Message clientMessage, Throwable ex) { return super.handleClientMessageProcessingError(clientMessage, ex); } @Override public Message handleErrorMessageToClient(Message errorMessage) { return super.handleErrorMessageToClient(errorMessage); } @Override protected Message handleInternal(StompHeaderAccessor errorHeaderAccessor, byte[] errorPayload, Throwable cause, StompHeaderAccessor clientHeaderAccessor) { return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor); } }); } @Override public void configureMessageBroker(MessageBrokerRegistry registry) { registry.enableSimpleBroker("/topic"); registry.setApplicationDestinationPrefixes("/front"); registry.setUserDestinationPrefix("/topic/user"); } /** * stomp未登录验证 * @param registration */ @Override public void configureClientInboundChannel(ChannelRegistration registration) { registration.interceptors(new ChannelInterceptor() { @Override public Message preSend(Message message, MessageChannel channel) { StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); if (StompCommand.CONNECT.equals(accessor.getCommand())) { String token = accessor.getNativeHeader("access_token").get(0); OAuth2Authorization authorization = authorizationService.findByToken(token, OAuth2TokenType.ACCESS_TOKEN); if (authorization == null) { throw new AccessDeniedException("Access is denied"); } else { accessor.setUser(authorization.getAttribute("java.security.Principal")); } } return message; } }); } }