diff --git a/vital-signs/src/main/java/com/rax/vital/config/WebSocketConfig.java b/vital-signs/src/main/java/com/rax/vital/config/WebSocketConfig.java index 0c6a83d..6c782cf 100644 --- a/vital-signs/src/main/java/com/rax/vital/config/WebSocketConfig.java +++ b/vital-signs/src/main/java/com/rax/vital/config/WebSocketConfig.java @@ -1,19 +1,12 @@ package com.rax.vital.config; +import com.rax.vital.interceptor.WSChannelInterceptor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.messaging.Message; -import org.springframework.messaging.MessageChannel; import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; -import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; -import org.springframework.messaging.support.ChannelInterceptor; -import org.springframework.messaging.support.MessageHeaderAccessor; -import org.springframework.security.access.AccessDeniedException; -import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; @@ -24,7 +17,7 @@ import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler; public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { @Autowired - private OAuth2AuthorizationService authorizationService; + private WSChannelInterceptor wsChannelInterceptor; @Override public void registerStompEndpoints(StompEndpointRegistry registry) { @@ -32,16 +25,36 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { .setAllowedOrigins("*"); // 错误处理 registry.setErrorHandler(new StompSubProtocolErrorHandler() { + /** + * + * @param clientMessage the client message related to the error, possibly + * {@code null} if error occurred while parsing a WebSocket message + * @param ex the cause for the error, never {@code null} + * @return + */ @Override public Message handleClientMessageProcessingError(Message clientMessage, Throwable ex) { return super.handleClientMessageProcessingError(clientMessage, ex); } + /** + * + * @param errorMessage the error message, never {@code null} + * @return + */ @Override public Message handleErrorMessageToClient(Message errorMessage) { return super.handleErrorMessageToClient(errorMessage); } + /** + * + * @param errorHeaderAccessor + * @param errorPayload + * @param cause + * @param clientHeaderAccessor + * @return + */ @Override protected Message handleInternal(StompHeaderAccessor errorHeaderAccessor, byte[] errorPayload, Throwable cause, StompHeaderAccessor clientHeaderAccessor) { return super.handleInternal(errorHeaderAccessor, errorPayload, cause, clientHeaderAccessor); @@ -58,27 +71,12 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { /** * stomp未登录验证 + * * @param registration */ @Override public void configureClientInboundChannel(ChannelRegistration registration) { - registration.interceptors(new ChannelInterceptor() { - @Override - public Message preSend(Message message, MessageChannel channel) { - StompHeaderAccessor accessor = - MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); - if (StompCommand.CONNECT.equals(accessor.getCommand())) { - String token = accessor.getNativeHeader("access_token").get(0); - OAuth2Authorization authorization = authorizationService.findByToken(token, OAuth2TokenType.ACCESS_TOKEN); - if (authorization == null) { - throw new AccessDeniedException("Access is denied"); - } else { - accessor.setUser(authorization.getAttribute("java.security.Principal")); - } - } - return message; - } - }); + registration.interceptors(wsChannelInterceptor); } } diff --git a/vital-signs/src/main/java/com/rax/vital/interceptor/WSChannelInterceptor.java b/vital-signs/src/main/java/com/rax/vital/interceptor/WSChannelInterceptor.java new file mode 100644 index 0000000..2c0a19f --- /dev/null +++ b/vital-signs/src/main/java/com/rax/vital/interceptor/WSChannelInterceptor.java @@ -0,0 +1,60 @@ +package com.rax.vital.interceptor; + +import com.rax.vital.timer.VitalSignTimer; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.MessageHeaderAccessor; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; +import org.springframework.stereotype.Component; + +import java.util.List; + +@Slf4j +@Component +public class WSChannelInterceptor implements ChannelInterceptor { + @Autowired + private OAuth2AuthorizationService authorizationService; + + @Autowired + private VitalSignTimer vitalSignTimer; + + @Override + public Message preSend(Message message, MessageChannel channel) { + + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + + if (accessor != null) { + + List accessToken = accessor.getNativeHeader("token"); + + if (accessToken != null && !accessToken.isEmpty()) { + + String token = accessToken.get(0); + OAuth2Authorization authorization = authorizationService.findByToken(token, OAuth2TokenType.ACCESS_TOKEN); + + if (StompCommand.CONNECT.equals(accessor.getCommand())) { + if (authorization == null) { + throw new AccessDeniedException("Access is denied"); + } + } + } + + if (StompCommand.ABORT.equals(accessor.getCommand())) { + System.out.println("StompCommand.ABORT"); + } else if (StompCommand.DISCONNECT.equals(accessor.getCommand()) + || StompCommand.UNSUBSCRIBE.equals(accessor.getCommand())) { + String simpSessionId = (String) accessor.getHeader("simpSessionId"); + vitalSignTimer.stopTimerTaskMongo(simpSessionId); + } + } + return message; + } +} diff --git a/vital-signs/src/main/java/com/rax/vital/medicine/controller/MedicineController.java b/vital-signs/src/main/java/com/rax/vital/medicine/controller/MedicineController.java index b135976..ccaeeee 100644 --- a/vital-signs/src/main/java/com/rax/vital/medicine/controller/MedicineController.java +++ b/vital-signs/src/main/java/com/rax/vital/medicine/controller/MedicineController.java @@ -6,17 +6,20 @@ import com.rax.vital.medicine.service.DoctorMedicineService; import com.rax.vital.timer.VitalSignTimer; import io.swagger.v3.oas.annotations.security.SecurityRequirement; import io.swagger.v3.oas.annotations.tags.Tag; -import lombok.RequiredArgsConstructor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.HttpHeaders; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.security.access.AccessDeniedException; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; -import java.security.Principal; - /** * 用药 + * * @date 2024.2.19 */ @RestController @@ -34,14 +37,19 @@ public class MedicineController { @Autowired private VitalSignTimer vitalSignTimer; + @Autowired + private OAuth2AuthorizationService authorizationService; + @MessageMapping("/getSurgeryData") - public void doctorMedicine(Principal principal, String body) { + public void doctorMedicine(MessageHeaders messageHeaders, String body) { JSONObject params = JSONObject.parseObject(body); - String username = principal.getName(); - if ("stop".equals(params.getString("status"))) { - vitalSignTimer.stopTimerTaskMongo(params.getString("db"), username); + OAuth2Authorization authorization = authorizationService.findByToken(params.getString("token"), OAuth2TokenType.ACCESS_TOKEN); + if (authorization != null) { + String username = authorization.getPrincipalName(); + String simpSessionId = messageHeaders.get("simpSessionId", String.class); + vitalSignTimer.createAndSendMessageMongo(params.getString("db"), username, simpSessionId); } else { - vitalSignTimer.createAndSendMessageMongo(params.getString("db"), username); + throw new AccessDeniedException("Access is denied"); } } diff --git a/vital-signs/src/main/java/com/rax/vital/timer/VitalSignTimer.java b/vital-signs/src/main/java/com/rax/vital/timer/VitalSignTimer.java index 8344ad6..c7c0160 100644 --- a/vital-signs/src/main/java/com/rax/vital/timer/VitalSignTimer.java +++ b/vital-signs/src/main/java/com/rax/vital/timer/VitalSignTimer.java @@ -78,18 +78,17 @@ public class VitalSignTimer { * * @author zhaoyz */ - public void createAndSendMessageMongo(String database, String account) { - TimerTask task = timerMongoTaskMap.get(account + ":" + database); + public void createAndSendMessageMongo(String database, String username, String simpSessionId) { + TimerTask task = timerMongoTaskMap.get(simpSessionId); if (task != null) { return; } - MongoDBSource mongoDBSource = mongoDBSourceMap.get(database); + MongoDBSource mongoDBSource = mongoDBSourceMap.get(simpSessionId); if (mongoDBSource == null) { mongoDBSource = new MongoDBSource(mongoDBHost, mongoPassword, mongoUsername, database); - mongoDBSourceMap.put(database, mongoDBSource); + mongoDBSourceMap.put(simpSessionId, mongoDBSource); mongoDBSource.open(); - mongoDBSource.increaseCount(); } MongoDBSource finalMongoDBSource = mongoDBSource; @@ -110,14 +109,14 @@ public class VitalSignTimer { List flags = flagService.getFlags(template); result.put("flags", flags); - simpMessagingTemplate.convertAndSendToUser(account + ":" + database, "/surgeryData", result); + simpMessagingTemplate.convertAndSendToUser(username + ":" + database, "/surgeryData", result); } }; // 定时任务,设置1秒 Timer timer = new Timer(); timer.schedule(timerTask, 0, 1000); - timerMongoTaskMap.put(account + ":" + database, timerTask); + timerMongoTaskMap.put(simpSessionId, timerTask); } /** @@ -166,22 +165,16 @@ public class VitalSignTimer { /** * 停止指定的某个用户查询的患者数据库定时器,数据库类型是MongoDB * - * @param database * @author zhaoyz */ - public synchronized void stopTimerTaskMongo(String database, String account) { - TimerTask timerTask = timerMongoTaskMap.get(account + ":" + database); + public synchronized void stopTimerTaskMongo(String simpSessionId) { + TimerTask timerTask = timerMongoTaskMap.get(simpSessionId); if (timerTask != null) { timerTask.cancel(); - timerMongoTaskMap.remove(account + ":" + database); - - MongoDBSource mongoDBSource = mongoDBSourceMap.get(database); - mongoDBSource.decreaseCount(); - int count = mongoDBSource.getCount(); - if (count == 0) { - mongoDBSource.close(); - mongoDBSourceMap.remove(database); - } + MongoDBSource mongoDBSource = mongoDBSourceMap.get(simpSessionId); + mongoDBSource.close(); + timerMongoTaskMap.remove(simpSessionId); + mongoDBSourceMap.remove(simpSessionId); } }