rax-remote-2/vital-signs/src/main/java/com/rax/vital/config/WebSocketConfig.java
2024-03-15 16:55:59 +08:00

85 lines
3.5 KiB
Java

package com.rax.vital.config;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler;
/**
 * STOMP-over-WebSocket configuration.
 *
 * <p>Registers the STOMP endpoints, configures the simple in-memory broker and its
 * destination prefixes, and installs an inbound channel interceptor that authenticates
 * every STOMP {@code CONNECT} frame against the OAuth2 authorization service.
 */
@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

    /** Native STOMP header that must carry the OAuth2 access token on {@code CONNECT}. */
    private static final String ACCESS_TOKEN_HEADER = "access_token";

    /** Attribute of the {@link OAuth2Authorization} that is installed as the session user. */
    private static final String PRINCIPAL_ATTRIBUTE = "java.security.Principal";

    /** Message used for every authentication failure, so callers cannot tell the cases apart. */
    private static final String ACCESS_DENIED_MESSAGE = "Access is denied";

    @Autowired
    private OAuth2AuthorizationService authorizationService;

    /**
     * Registers the STOMP endpoints and the sub-protocol error handler.
     *
     * @param registry the endpoint registry supplied by Spring
     */
    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        // NOTE(review): "*" accepts WebSocket handshakes from any origin. Access is still
        // gated by the token check on CONNECT, but consider restricting the origins.
        registry.addEndpoint(
                        "/rax/chat",
                        "/rax/ai-medicine",
                        "/rax/doctor-medicine",
                        "/rax/vital-signs",
                        "/rax/SurgeryData")
                .setAllowedOrigins("*");

        // Error handling: the stock handler is used as-is. Subclass it here when custom
        // ERROR frames are needed.
        registry.setErrorHandler(new StompSubProtocolErrorHandler());
    }

    /**
     * Configures the simple broker and the destination prefixes.
     *
     * @param registry the broker registry supplied by Spring
     */
    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        registry.enableSimpleBroker("/topic");
        registry.setApplicationDestinationPrefixes("/front");
        registry.setUserDestinationPrefix("/topic/user");
    }

    /**
     * Installs the interceptor that rejects STOMP connections from clients that are not
     * logged in.
     *
     * @param registration the inbound channel registration supplied by Spring
     */
    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        registration.interceptors(new ChannelInterceptor() {
            @Override
            public Message<?> preSend(Message<?> message, MessageChannel channel) {
                StompHeaderAccessor accessor =
                        MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
                // getAccessor may return null when the message carries no STOMP accessor;
                // such a message cannot be a CONNECT frame, so it is passed through.
                if (accessor != null && StompCommand.CONNECT.equals(accessor.getCommand())) {
                    authenticate(accessor);
                }
                return message;
            }
        });
    }

    /**
     * Authenticates a STOMP {@code CONNECT} frame and binds the resolved principal to it.
     *
     * @param accessor header accessor of the CONNECT frame
     * @throws AccessDeniedException if the access-token header is missing or blank, or if
     *     the authorization service knows no authorization for the token
     */
    private void authenticate(StompHeaderAccessor accessor) {
        // getFirstNativeHeader yields null for an absent header instead of failing the way
        // getNativeHeader(...).get(0) does.
        String token = accessor.getFirstNativeHeader(ACCESS_TOKEN_HEADER);
        // Reject blank tokens up front rather than passing them to the service
        // (presumably the service would reject them with a different exception; verify).
        if (token == null || token.trim().isEmpty()) {
            throw new AccessDeniedException(ACCESS_DENIED_MESSAGE);
        }

        OAuth2Authorization authorization =
                authorizationService.findByToken(token, OAuth2TokenType.ACCESS_TOKEN);
        if (authorization == null) {
            throw new AccessDeniedException(ACCESS_DENIED_MESSAGE);
        }

        accessor.setUser(authorization.getAttribute(PRINCIPAL_ATTRIBUTE));
    }
}