diff --git a/upms/upms-biz/src/main/resources/application.yml b/upms/upms-biz/src/main/resources/application.yml index 90589dc..6c69c5c 100644 --- a/upms/upms-biz/src/main/resources/application.yml +++ b/upms/upms-biz/src/main/resources/application.yml @@ -76,9 +76,9 @@ security: - /error - /actuator/** - /code/** + - /rax/** # 临时白名单 - /static/** -# - /rax/** # - /topic/** # - /front/** # - /medicine/** diff --git a/vital-signs/src/main/java/com/rax/vital/config/WebSocketConfig.java b/vital-signs/src/main/java/com/rax/vital/config/WebSocketConfig.java index 0c57f68..0c6a83d 100644 --- a/vital-signs/src/main/java/com/rax/vital/config/WebSocketConfig.java +++ b/vital-signs/src/main/java/com/rax/vital/config/WebSocketConfig.java @@ -1,5 +1,6 @@ package com.rax.vital.config; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; @@ -9,19 +10,43 @@ import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.messaging.support.MessageHeaderAccessor; -import org.springframework.security.core.Authentication; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; +import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler; @Configuration @EnableWebSocketMessageBroker public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { + @Autowired + private OAuth2AuthorizationService authorizationService; + @Override public void registerStompEndpoints(StompEndpointRegistry registry) { registry.addEndpoint("/rax/chat", "/rax/ai-medicine", "/rax/doctor-medicine", "/rax/vital-signs", "/rax/SurgeryData") .setAllowedOrigins("*"); + // 错误处理 + registry.setErrorHandler(new StompSubProtocolErrorHandler() { + @Override + public Message handleClientMessageProcessingError(Message clientMessage, Throwable ex) { + return super.handleClientMessageProcessingError(clientMessage, ex); + } + + @Override + public Message handleErrorMessageToClient(Message errorMessage) { + return super.handleErrorMessageToClient(errorMessage); + } + + @Override + protected Message handleInternal(StompHeaderAccessor errorHeaderAccessor, byte[] errorPayload, Throwable cause, StompHeaderAccessor clientHeaderAccessor) { + return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor); + } + }); } @Override @@ -31,6 +56,10 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { registry.setUserDestinationPrefix("/topic/user"); } + /** + * stomp未登录验证 + * @param registration + */ @Override public void configureClientInboundChannel(ChannelRegistration registration) { registration.interceptors(new ChannelInterceptor() { @@ -39,8 +68,13 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); if (StompCommand.CONNECT.equals(accessor.getCommand())) { - //Authentication user = ... ; // access authentication header(s) -// accessor.setUser(user); + String token = accessor.getNativeHeader("access_token").get(0); + OAuth2Authorization authorization = authorizationService.findByToken(token, OAuth2TokenType.ACCESS_TOKEN); + if (authorization == null) { + throw new AccessDeniedException("Access is denied"); + } else { + accessor.setUser(authorization.getAttribute("java.security.Principal")); + } } return message; } diff --git a/vital-signs/src/main/java/com/rax/vital/medicine/controller/MedicineController.java b/vital-signs/src/main/java/com/rax/vital/medicine/controller/MedicineController.java index 967f9a8..b135976 100644 --- a/vital-signs/src/main/java/com/rax/vital/medicine/controller/MedicineController.java +++ b/vital-signs/src/main/java/com/rax/vital/medicine/controller/MedicineController.java @@ -13,6 +13,8 @@ import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; +import java.security.Principal; + /** * 用药 * @date 2024.2.19 @@ -33,12 +35,13 @@ public class MedicineController { private VitalSignTimer vitalSignTimer; @MessageMapping("/getSurgeryData") - public void doctorMedicine(String body) { + public void doctorMedicine(Principal principal, String body) { JSONObject params = JSONObject.parseObject(body); + String username = principal.getName(); if ("stop".equals(params.getString("status"))) { - vitalSignTimer.stopTimerTaskMongo(params.getString("db"), params.getString("username")); + vitalSignTimer.stopTimerTaskMongo(params.getString("db"), username); } else { - vitalSignTimer.createAndSendMessageMongo(params.getString("db"), params.getString("username")); + vitalSignTimer.createAndSendMessageMongo(params.getString("db"), username); } } diff --git a/vital-signs/src/main/java/com/rax/vital/timer/VitalSignTimer.java b/vital-signs/src/main/java/com/rax/vital/timer/VitalSignTimer.java index 15b4295..8344ad6 100644 --- a/vital-signs/src/main/java/com/rax/vital/timer/VitalSignTimer.java +++ b/vital-signs/src/main/java/com/rax/vital/timer/VitalSignTimer.java @@ -78,9 +78,8 @@ public class VitalSignTimer { * * @author zhaoyz */ - public void createAndSendMessageMongo(String database, String user) { - String account = SecurityUtils.getUser().getUsername(); - TimerTask task = timerMongoTaskMap.get(account + ":" + user + ":" + database); + public void createAndSendMessageMongo(String database, String account) { + TimerTask task = timerMongoTaskMap.get(account + ":" + database); if (task != null) { return; } @@ -111,14 +110,14 @@ public class VitalSignTimer { List flags = flagService.getFlags(template); result.put("flags", flags); - simpMessagingTemplate.convertAndSendToUser(account + ":" + user, "/surgeryData", result); + simpMessagingTemplate.convertAndSendToUser(account + ":" + database, "/surgeryData", result); } }; // 定时任务,设置1秒 Timer timer = new Timer(); timer.schedule(timerTask, 0, 1000); - timerMongoTaskMap.put(account + ":" + user + ":" + database, timerTask); + timerMongoTaskMap.put(account + ":" + database, timerTask); } /** @@ -170,12 +169,11 @@ public class VitalSignTimer { * @param database * @author zhaoyz */ - public synchronized void stopTimerTaskMongo(String database, String user) { - String account = SecurityUtils.getUser().getUsername(); - TimerTask timerTask = timerMongoTaskMap.get(account + ":" + user + ":" + database); + public synchronized void stopTimerTaskMongo(String database, String account) { + TimerTask timerTask = timerMongoTaskMap.get(account + ":" + database); if (timerTask != null) { timerTask.cancel(); - timerMongoTaskMap.remove(account + ":" + user + ":" + database); + timerMongoTaskMap.remove(account + ":" + database); MongoDBSource mongoDBSource = mongoDBSourceMap.get(database); mongoDBSource.decreaseCount();