stomp channelInterceptor中根据token中ws的sessionId进行查询数据断开连接

This commit is contained in:
zhaoyz 2024-03-19 09:26:50 +08:00
parent 458a79d66e
commit a93864d985
4 changed files with 112 additions and 53 deletions

View File

@ -1,19 +1,12 @@
package com.rax.vital.config;
import com.rax.vital.interceptor.WSChannelInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
@ -24,7 +17,7 @@ import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler;
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
@Autowired
private OAuth2AuthorizationService authorizationService;
private WSChannelInterceptor wsChannelInterceptor;
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
@ -32,16 +25,36 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
.setAllowedOrigins("*");
// 错误处理
registry.setErrorHandler(new StompSubProtocolErrorHandler() {
/**
*
* @param clientMessage the client message related to the error, possibly
* {@code null} if error occurred while parsing a WebSocket message
* @param ex the cause for the error, never {@code null}
* @return
*/
@Override
public Message<byte[]> handleClientMessageProcessingError(Message<byte[]> clientMessage, Throwable ex) {
return super.handleClientMessageProcessingError(clientMessage, ex);
}
/**
*
* @param errorMessage the error message, never {@code null}
* @return
*/
@Override
public Message<byte[]> handleErrorMessageToClient(Message<byte[]> errorMessage) {
return super.handleErrorMessageToClient(errorMessage);
}
/**
*
* @param errorHeaderAccessor
* @param errorPayload
* @param cause
* @param clientHeaderAccessor
* @return
*/
@Override
protected Message<byte[]> handleInternal(StompHeaderAccessor errorHeaderAccessor, byte[] errorPayload, Throwable cause, StompHeaderAccessor clientHeaderAccessor) {
return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor);
@ -58,27 +71,12 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
/**
* stomp未登录验证
*
* @param registration
*/
@Override
public void configureClientInboundChannel(ChannelRegistration registration) {
registration.interceptors(new ChannelInterceptor() {
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
StompHeaderAccessor accessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (StompCommand.CONNECT.equals(accessor.getCommand())) {
String token = accessor.getNativeHeader("access_token").get(0);
OAuth2Authorization authorization = authorizationService.findByToken(token, OAuth2TokenType.ACCESS_TOKEN);
if (authorization == null) {
throw new AccessDeniedException("Access is denied");
} else {
accessor.setUser(authorization.getAttribute("java.security.Principal"));
}
}
return message;
}
});
registration.interceptors(wsChannelInterceptor);
}
}

View File

@ -0,0 +1,60 @@
package com.rax.vital.interceptor;
import com.rax.vital.timer.VitalSignTimer;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.stereotype.Component;
import java.util.List;
@Slf4j
@Component
public class WSChannelInterceptor implements ChannelInterceptor {
@Autowired
private OAuth2AuthorizationService authorizationService;
@Autowired
private VitalSignTimer vitalSignTimer;
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (accessor != null) {
List<String> accessToken = accessor.getNativeHeader("token");
if (accessToken != null && !accessToken.isEmpty()) {
String token = accessToken.get(0);
OAuth2Authorization authorization = authorizationService.findByToken(token, OAuth2TokenType.ACCESS_TOKEN);
if (StompCommand.CONNECT.equals(accessor.getCommand())) {
if (authorization == null) {
throw new AccessDeniedException("Access is denied");
}
}
}
if (StompCommand.ABORT.equals(accessor.getCommand())) {
System.out.println("StompCommand.ABORT");
} else if (StompCommand.DISCONNECT.equals(accessor.getCommand())
|| StompCommand.UNSUBSCRIBE.equals(accessor.getCommand())) {
String simpSessionId = (String) accessor.getHeader("simpSessionId");
vitalSignTimer.stopTimerTaskMongo(simpSessionId);
}
}
return message;
}
}

View File

@ -6,17 +6,20 @@ import com.rax.vital.medicine.service.DoctorMedicineService;
import com.rax.vital.timer.VitalSignTimer;
import io.swagger.v3.oas.annotations.security.SecurityRequirement;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpHeaders;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.security.Principal;
/**
* 用药
*
* @date 2024.2.19
*/
@RestController
@ -34,14 +37,19 @@ public class MedicineController {
@Autowired
private VitalSignTimer vitalSignTimer;
@Autowired
private OAuth2AuthorizationService authorizationService;
@MessageMapping("/getSurgeryData")
public void doctorMedicine(Principal principal, String body) {
public void doctorMedicine(MessageHeaders messageHeaders, String body) {
JSONObject params = JSONObject.parseObject(body);
String username = principal.getName();
if ("stop".equals(params.getString("status"))) {
vitalSignTimer.stopTimerTaskMongo(params.getString("db"), username);
OAuth2Authorization authorization = authorizationService.findByToken(params.getString("token"), OAuth2TokenType.ACCESS_TOKEN);
if (authorization != null) {
String username = authorization.getPrincipalName();
String simpSessionId = messageHeaders.get("simpSessionId", String.class);
vitalSignTimer.createAndSendMessageMongo(params.getString("db"), username, simpSessionId);
} else {
vitalSignTimer.createAndSendMessageMongo(params.getString("db"), username);
throw new AccessDeniedException("Access is denied");
}
}

View File

@ -78,18 +78,17 @@ public class VitalSignTimer {
*
* @author zhaoyz
*/
public void createAndSendMessageMongo(String database, String account) {
TimerTask task = timerMongoTaskMap.get(account + ":" + database);
public void createAndSendMessageMongo(String database, String username, String simpSessionId) {
TimerTask task = timerMongoTaskMap.get(simpSessionId);
if (task != null) {
return;
}
MongoDBSource mongoDBSource = mongoDBSourceMap.get(database);
MongoDBSource mongoDBSource = mongoDBSourceMap.get(simpSessionId);
if (mongoDBSource == null) {
mongoDBSource = new MongoDBSource(mongoDBHost, mongoPassword, mongoUsername, database);
mongoDBSourceMap.put(database, mongoDBSource);
mongoDBSourceMap.put(simpSessionId, mongoDBSource);
mongoDBSource.open();
mongoDBSource.increaseCount();
}
MongoDBSource finalMongoDBSource = mongoDBSource;
@ -110,14 +109,14 @@ public class VitalSignTimer {
List flags = flagService.getFlags(template);
result.put("flags", flags);
simpMessagingTemplate.convertAndSendToUser(account + ":" + database, "/surgeryData", result);
simpMessagingTemplate.convertAndSendToUser(username + ":" + database, "/surgeryData", result);
}
};
// 定时任务设置1秒
Timer timer = new Timer();
timer.schedule(timerTask, 0, 1000);
timerMongoTaskMap.put(account + ":" + database, timerTask);
timerMongoTaskMap.put(simpSessionId, timerTask);
}
/**
@ -166,22 +165,16 @@ public class VitalSignTimer {
/**
* 停止指定的某个用户查询的患者数据库定时器数据库类型是MongoDB
*
* @param database
* @author zhaoyz
*/
public synchronized void stopTimerTaskMongo(String database, String account) {
TimerTask timerTask = timerMongoTaskMap.get(account + ":" + database);
public synchronized void stopTimerTaskMongo(String simpSessionId) {
TimerTask timerTask = timerMongoTaskMap.get(simpSessionId);
if (timerTask != null) {
timerTask.cancel();
timerMongoTaskMap.remove(account + ":" + database);
MongoDBSource mongoDBSource = mongoDBSourceMap.get(database);
mongoDBSource.decreaseCount();
int count = mongoDBSource.getCount();
if (count == 0) {
mongoDBSource.close();
mongoDBSourceMap.remove(database);
}
MongoDBSource mongoDBSource = mongoDBSourceMap.get(simpSessionId);
mongoDBSource.close();
timerMongoTaskMap.remove(simpSessionId);
mongoDBSourceMap.remove(simpSessionId);
}
}