stomp请求通过token验证

This commit is contained in:
zhaoyz 2024-03-15 16:55:59 +08:00
parent 971d90abb0
commit 458a79d66e
4 changed files with 51 additions and 16 deletions

View File

@ -76,9 +76,9 @@ security:
- /error - /error
- /actuator/** - /actuator/**
- /code/** - /code/**
- /rax/**
# 临时白名单 # 临时白名单
- /static/** - /static/**
# - /rax/**
# - /topic/** # - /topic/**
# - /front/** # - /front/**
# - /medicine/** # - /medicine/**

View File

@ -1,5 +1,6 @@
package com.rax.vital.config; package com.rax.vital.config;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageChannel;
@ -9,19 +10,43 @@ import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor; import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.core.Authentication; import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler;
@Configuration @Configuration
@EnableWebSocketMessageBroker @EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
@Autowired
private OAuth2AuthorizationService authorizationService;
@Override @Override
public void registerStompEndpoints(StompEndpointRegistry registry) { public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/rax/chat", "/rax/ai-medicine", "/rax/doctor-medicine", "/rax/vital-signs", "/rax/SurgeryData") registry.addEndpoint("/rax/chat", "/rax/ai-medicine", "/rax/doctor-medicine", "/rax/vital-signs", "/rax/SurgeryData")
.setAllowedOrigins("*"); .setAllowedOrigins("*");
// 错误处理
registry.setErrorHandler(new StompSubProtocolErrorHandler() {
@Override
public Message<byte[]> handleClientMessageProcessingError(Message<byte[]> clientMessage, Throwable ex) {
return super.handleClientMessageProcessingError(clientMessage, ex);
}
@Override
public Message<byte[]> handleErrorMessageToClient(Message<byte[]> errorMessage) {
return super.handleErrorMessageToClient(errorMessage);
}
@Override
protected Message<byte[]> handleInternal(StompHeaderAccessor errorHeaderAccessor, byte[] errorPayload, Throwable cause, StompHeaderAccessor clientHeaderAccessor) {
return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor);
}
});
} }
@Override @Override
@ -31,6 +56,10 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
registry.setUserDestinationPrefix("/topic/user"); registry.setUserDestinationPrefix("/topic/user");
} }
/**
* stomp未登录验证
* @param registration
*/
@Override @Override
public void configureClientInboundChannel(ChannelRegistration registration) { public void configureClientInboundChannel(ChannelRegistration registration) {
registration.interceptors(new ChannelInterceptor() { registration.interceptors(new ChannelInterceptor() {
@ -39,8 +68,13 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
StompHeaderAccessor accessor = StompHeaderAccessor accessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (StompCommand.CONNECT.equals(accessor.getCommand())) { if (StompCommand.CONNECT.equals(accessor.getCommand())) {
//Authentication user = ... ; // access authentication header(s) String token = accessor.getNativeHeader("access_token").get(0);
// accessor.setUser(user); OAuth2Authorization authorization = authorizationService.findByToken(token, OAuth2TokenType.ACCESS_TOKEN);
if (authorization == null) {
throw new AccessDeniedException("Access is denied");
} else {
accessor.setUser(authorization.getAttribute("java.security.Principal"));
}
} }
return message; return message;
} }

View File

@ -13,6 +13,8 @@ import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import java.security.Principal;
/** /**
* 用药 * 用药
* @date 2024.2.19 * @date 2024.2.19
@ -33,12 +35,13 @@ public class MedicineController {
private VitalSignTimer vitalSignTimer; private VitalSignTimer vitalSignTimer;
@MessageMapping("/getSurgeryData") @MessageMapping("/getSurgeryData")
public void doctorMedicine(String body) { public void doctorMedicine(Principal principal, String body) {
JSONObject params = JSONObject.parseObject(body); JSONObject params = JSONObject.parseObject(body);
String username = principal.getName();
if ("stop".equals(params.getString("status"))) { if ("stop".equals(params.getString("status"))) {
vitalSignTimer.stopTimerTaskMongo(params.getString("db"), params.getString("username")); vitalSignTimer.stopTimerTaskMongo(params.getString("db"), username);
} else { } else {
vitalSignTimer.createAndSendMessageMongo(params.getString("db"), params.getString("username")); vitalSignTimer.createAndSendMessageMongo(params.getString("db"), username);
} }
} }

View File

@ -78,9 +78,8 @@ public class VitalSignTimer {
* *
* @author zhaoyz * @author zhaoyz
*/ */
public void createAndSendMessageMongo(String database, String user) { public void createAndSendMessageMongo(String database, String account) {
String account = SecurityUtils.getUser().getUsername(); TimerTask task = timerMongoTaskMap.get(account + ":" + database);
TimerTask task = timerMongoTaskMap.get(account + ":" + user + ":" + database);
if (task != null) { if (task != null) {
return; return;
} }
@ -111,14 +110,14 @@ public class VitalSignTimer {
List flags = flagService.getFlags(template); List flags = flagService.getFlags(template);
result.put("flags", flags); result.put("flags", flags);
simpMessagingTemplate.convertAndSendToUser(account + ":" + user, "/surgeryData", result); simpMessagingTemplate.convertAndSendToUser(account + ":" + database, "/surgeryData", result);
} }
}; };
// 定时任务设置1秒 // 定时任务设置1秒
Timer timer = new Timer(); Timer timer = new Timer();
timer.schedule(timerTask, 0, 1000); timer.schedule(timerTask, 0, 1000);
timerMongoTaskMap.put(account + ":" + user + ":" + database, timerTask); timerMongoTaskMap.put(account + ":" + database, timerTask);
} }
/** /**
@ -170,12 +169,11 @@ public class VitalSignTimer {
* @param database * @param database
* @author zhaoyz * @author zhaoyz
*/ */
public synchronized void stopTimerTaskMongo(String database, String user) { public synchronized void stopTimerTaskMongo(String database, String account) {
String account = SecurityUtils.getUser().getUsername(); TimerTask timerTask = timerMongoTaskMap.get(account + ":" + database);
TimerTask timerTask = timerMongoTaskMap.get(account + ":" + user + ":" + database);
if (timerTask != null) { if (timerTask != null) {
timerTask.cancel(); timerTask.cancel();
timerMongoTaskMap.remove(account + ":" + user + ":" + database); timerMongoTaskMap.remove(account + ":" + database);
MongoDBSource mongoDBSource = mongoDBSourceMap.get(database); MongoDBSource mongoDBSource = mongoDBSourceMap.get(database);
mongoDBSource.decreaseCount(); mongoDBSource.decreaseCount();