spring boot websocket

WebSocket原生注解

  • pom.xml
<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
  • 配置类WebSocketConfig,这里开启了配置之后springboot才会去扫描对应的注解
@Configuration
@EnableWebSocket
public class WebSocketConfig {

    @Bean
    public ServerEndpointExporter serverEndpoint() {
        return new ServerEndpointExporter();
    }
}
  • 处理消息类Msg
package com.osvue.env.app.msg;

import java.io.IOException;

import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;

import org.springframework.stereotype.Component;
/**
 * 
 * @ClassName:  Msg   
 * @Description:TODO(web socket 原生注解)   
 * @author: hzq 
 * @date:   2023-4-25 11:08:30    
 * @Copyright: 2023
 */
@ServerEndpoint("/skt")
@Component
public class Msg {
    /**
     * 连接成功
     * @param session
     */
    @OnOpen
    public void onOpen(Session session) {
        System.out.println("连接成功");
    }

    /**
     * 连接关闭
     * @param session
     */
    @OnClose
    public void onClose(Session session) {
        System.out.println("连接关闭");
    }

    /**
     * 接收到消息
     * @param text
     */
    @OnMessage
    public String onMsg(String text) throws IOException {
        return "servet 发送:" + text;
    }
}
  • 后面如果要对客户端发送消息的话使用session.getBasicRemote().sendText(XXX)

  • @ServerEndpoint

  • 通过这个 spring boot 就可以知道你暴露出去的 ws 应用的路径,有点类似我们经常用的@RequestMapping
  • 比如你的启动端口是8080,而这个注解的值是ws,那我们就可以通过 ws://127.0.0.1:8080/ws 来连接你的应用。
  • @OnOpen
  • 当 websocket 建立连接成功后会触发这个注解修饰的方法。

  • @OnClose

  • 当 websocket 建立的连接断开后会触发这个注解修饰的方法。

  • @OnMessage

  • 当客户端发送消息到服务端时,会触发这个注解修改的方法,它有一个 String 入参表明客户端传入的值。
  • @OnError
  • 当 websocket 建立连接时出现异常会触发这个注解修饰的方法。

Spring封装的WebSocket

WebSocketConfig 配置 HttpAuthHandler myInterceptor HttpAuthHandler 定义连接开始 和关闭 myInterceptor 连接前 增加拦截 WsSessionManager

  • spring同样也为我们提供了WebSocket的封装,这种方式可以自己配置拦截器

  • 在tcp握手之前对请求进行一次处理,可以避免一些恶意的连接。

  • 配置类WebSocketConfig,通过实现 WebSocketConfigurer 类并覆盖相应的方法进行 websocket 的配置。

  • 我们主要覆盖 registerWebSocketHandlers 这个方法。通过向 WebSocketHandlerRegistry 设置不同参数来进行配置。

  • 其中 addHandler方法添加我们的 ws 的 handler 处理类,第二个参数是你暴露出的 ws 路径。

  • addInterceptors添加我们写的拦截器。setAllowedOrigins这个是关闭跨域校验。

  • WebSocketConfig

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
/**
 * 
 * @ClassName:  WebSocketConfig   
 * @Description:TODO(spring boot config)   
 * @author: hzq 
 * @date:   2023-4-3 15:57:26    
 * @Copyright: 2023
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    @Autowired
    private HttpAuthHandler httpAuthHandler;
    @Autowired
    private MyInterceptor myInterceptor;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry
                .addHandler(httpAuthHandler, "itg")
                .addInterceptors(myInterceptor)
                .setAllowedOrigins("*");
    }
}
  • 处理器和拦截器
  • HttpAuthHandler用于处理ws的消息,通过继承 TextWebSocketHandler 类并覆盖相应方法,可以对 websocket 的事件进行处理,这里可以同原生注解的那几个注解连起来看

  • afterConnectionEstablished 方法是在 socket 连接成功后被触发,同原生注解里的 @OnOpen 功能

  • afterConnectionClosed方法是在 socket 连接关闭后被触发,同原生注解里的 @OnClose 功能

  • handleTextMessage方法是在客户端发送信息时触发,同原生注解里的 @OnMessage 功能

/**
 * 
 * @ClassName:  HttpAuthHandler   
 * @Description:TODO(这里用一句话描述这个类的作用)   
 * @author: hzq 
 * @date:   2023-4-14 17:25:19    
 * @Copyright: 2023
 */
@Component
public class HttpAuthHandler extends TextWebSocketHandler {
    /**
     * socket 建立成功事件
     * @param session
     * @throws Exception
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        Object token = session.getAttributes().get("token");
        if (token != null) {
            // 用户连接成功,放入在线用户缓存
            WsSessionManager.add(token.toString(), session);
        } else {
            throw new RuntimeException("用户登录已经失效!");
        }
    }
    /**
     * 接收消息事件
     * @param session
     * @param message
     * @throws Exception
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        // 获得客户端传来的消息
        String payload = message.getPayload();
        Object token = session.getAttributes().get("token");
        System.out.println("server 接收到 " + token + " 发送的 " + payload);
        session.sendMessage(new TextMessage("server 发送给 " + token + " 消息 " + payload + " " +    LocalDateTime.now().toString()));
    }

    /**
     * socket 断开连接时
     * @param session
     * @param status
     * @throws Exception
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        Object token = session.getAttributes().get("token");
        if (token != null) {
            // 用户退出,移除缓存
            WsSessionManager.remove(token.toString());
        }
    }
}

  • MyInterceptor用来拦截ws请求
    • 通过实现 HandshakeInterceptor 接口来定义握手拦截器,注意这里与上面 Handler 的事件是不同的
    • 这里是建立握手时的事件,分为握手前与握手后,而 Handler 的事件是在握手成功后的基础上建立 socket 的连接。
    • 所以在如果把认证放在这个步骤相对来说最节省服务器资源。它主要有两个方法 beforeHandshake 与 afterHandshake,
    • 顾名思义一个在握手前触发,一个在握手后触发
package com.osvue.env.app.msg;

import java.util.HashMap;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
/**
 * 
 * @ClassName:  MyInterceptor   
 * @Description:TODO(拦截器 --》 http Auth handler)   
 * @author: hzq 
 * @date:   2023-4-14 17:27:04    
 * @Copyright: 2023
 */
@Component
public class MyInterceptor implements HandshakeInterceptor {
    /**
     * 握手前
     * @param request
     * @param response
     * @param wsHandler
     * @param attributes
     * @return
     * @throws Exception
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        System.out.println("握手开始");
        // 获得请求参数
//        HashMap<String, String> paramMap = HttpUtil.decodeParamMap(request.getURI().getQuery(), "utf-8");
//        String uid = paramMap.get("token");
        String uid = "paramMap.";
        System.out.println(request.getURI().getQuery());
        if (StringUtils.isNotBlank(uid)) {
            // 放入属性域
            attributes.put("token", uid);
            System.out.println("用户 token " + uid + " 握手成功!");
            return true;
        }
        System.out.println("用户登录已失效");
        return false;
    }
    /**
     * 握手后
     * @param request
     * @param response
     * @param wsHandler
     * @param exception
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        System.out.println("握手完成");
    }
}
  • WsSessionManager
  • 简单通过 ConcurrentHashMap来实现了一个 session 池,用来保存已经登录的WebSocket 的 session。
  • 服务端发送消息给客户端必须要通过这个 session。
package com.osvue.env.app.msg;

import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.web.socket.WebSocketSession;

import lombok.extern.slf4j.Slf4j;
/**
 * 
 * @ClassName:  WsSessionManager   
 * @Description:TODO(这里用一句话描述这个类的作用)   
 * @author: hzq 
 * @date:   2023-4-14 17:27:50    
 * @Copyright: 2023
 */
@Slf4j
public class WsSessionManager {
    /**
     * 保存连接 session 的地方
     */
    private static ConcurrentHashMap<String, WebSocketSession> SESSION_POOL = new ConcurrentHashMap<>();

    /**
     * 添加 session
     * @param key
     */
    public static void add(String key, WebSocketSession session) {
        // 添加 session
        SESSION_POOL.put(key, session);
    }

    /**
     * 删除 session,会返回删除的 session
     * @param key
     * @return
     */
    public static WebSocketSession remove(String key) {
        // 删除 session
        return SESSION_POOL.remove(key);
    }

    /**
     * 删除并同步关闭连接
     * @param key
     */
    public static void removeAndClose(String key) {
        WebSocketSession session = remove(key);
        if (session != null) {
            try {
                // 关闭连接
                session.close();
            } catch (IOException e) {
                // todo: 关闭出现异常处理
                e.printStackTrace();
            }
        }
    }
    /**
     * 获得 session
     * @param key
     * @return
     */
    public static WebSocketSession get(String key) {
        // 获得 session
        return SESSION_POOL.get(key);
    }
}

个人实例

  • spring boot websocket
package com.osvue.env.app.ws;

import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
/**
 * 
 * @ClassName:  MsgHandle   
 * @Description:TODO(处理消息)   
 * @author: hzq 
 * @date:   2023-4-25 17:04:46    
 * @Copyright: 2023
 */
@ServerEndpoint("/itg/{userId}")
@Component
public class MsgHandle {
	
 
	    private static final Logger log = LoggerFactory.getLogger(MsgHandle.class);

	    /**
	     * 当前在线连接数
	     */
	    private static AtomicInteger onlineCount = new AtomicInteger(0);

	    /**
	     * 用来存放每个客户端对应的 WebSocketServer 对象
	     */
	    private static ConcurrentHashMap<String, MsgHandle> webSocketMap = new ConcurrentHashMap<>();

	    /**
	     * 与某个客户端的连接会话,需要通过它来给客户端发送数据
	     */
	    private Session session;

	    /**
	     * 接收 userId
	     */
	    private String userId = "";

	    /**
	     * 连接建立成功调用的方法
	     */
	    @OnOpen
	    public void onOpen(Session session, @PathParam("userId") String userId) {
	        this.session = session;
	        this.userId = userId;
	        if (webSocketMap.containsKey(userId)) {
	            webSocketMap.remove(userId);
	            webSocketMap.put(userId, this);
	        } else {
	            webSocketMap.put(userId, this);
	            addOnlineCount();
	        }
	        log.info("用户连接:" + userId + ",当前在线人数为:" + getOnlineCount());
	        try {
	            sendMessage("连接成功!");
	        } catch (IOException e) {
	            log.error("用户:" + userId + ",网络异常!!!!!!");
	        }
	    }

	    /**
	     * 连接关闭调用的方法
	     */
	    @OnClose
	    public void onClose() {
	        if (webSocketMap.containsKey(userId)) {
	            webSocketMap.remove(userId);
	            subOnlineCount();
	        }
	        log.info("用户退出:" + userId + ",当前在线人数为:" + getOnlineCount());
	    }

	    /**
	     * 收到客户端消息后调用的方法
	     *
	     * @param message 客户端发送过来的消息
	     */
	    @OnMessage
	    public void onMessage(String message, Session session) {
	        log.info("用户消息:" + userId + ",报文:" + message);
	        if (!StringUtils.isEmpty(message)) {
	            try {
	            	/*
	                JSONObject jsonObject = JSON.parseObject(message);
	                jsonObject.put("fromUserId", this.userId);
	                String toUserId = jsonObject.getString("toUserId");
	                if (!StringUtils.isEmpty(toUserId) && webSocketMap.containsKey(toUserId)) {
	                
	                    webSocketMap.get(toUserId).sendMessage(jsonObject.toJSONString());
	                } else {
	                    log.error("请求的 userId:" + toUserId + "不在该服务器上");
	                }
	                */
	            } catch (Exception e) {
	                e.printStackTrace();
	            }
	        }
	    }

	    /**
	     * 发生错误时调用
	     *
	     * @param session
	     * @param error
	     */
	    @OnError
	    public void onError(Session session, Throwable error) {
	        log.error("用户错误:" + this.userId + ",原因:" + error.getMessage());
	        error.printStackTrace();
	    }

	    /**
	     * 实现服务器主动推送
	     */
	    public void sendMessage(String message) throws IOException {
	        this.session.getBasicRemote().sendText(message);
	    }

	    public static synchronized AtomicInteger getOnlineCount() {
	        return onlineCount;
	    }

	    public static synchronized void addOnlineCount() {
	    	MsgHandle.onlineCount.getAndIncrement();
	    }

	    public static synchronized void subOnlineCount() {
	    	MsgHandle.onlineCount.getAndDecrement();
	    }

		public static ConcurrentHashMap<String, MsgHandle> getWebSocketMap() {
			return webSocketMap;
		}

		public static void setWebSocketMap(ConcurrentHashMap<String, MsgHandle> webSocketMap) {
			MsgHandle.webSocketMap = webSocketMap;
		}
	 
}




@Configuration
public class WebSocketConfig {
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
 
}




@RestController
@RequestMapping("/ims")
public class MsgController {

	@Resource
	MsgHandle mh;

	@GetMapping("msg/{msg}")
	public Map<String, Object> send(@PathVariable String msg) {
		Map<String, Object> mp = new HashMap<>();
		String amsg = "  定时发送  " + new Date().toLocaleString() + msg;
		try {
			mh.getWebSocketMap().get("abccd").sendMessage("你好 世界");
		} catch (IOException e) {
			e.printStackTrace();
		}
		return mp;
	}
}


  • 简单前台
<!DOCTYPE HTML>
<html>
<head>
    <title>My WebSocket</title>
</head>
 
<body>
Welcome<br/>
<input id="text" type="text" /><button onclick="send()">Send</button>    <button onclick="closeWebSocket()">Close</button>
<div id="message">
</div>
</body>
 
<script type="text/javascript">
    var websocket = null;
 
    //判断当前浏览器是否支持WebSocket  ,主要此处要更换为自己的地址
    if('WebSocket' in window){
        websocket = new WebSocket("ws://localhost:8005/itg/abccd");
    }
    else{
        alert('Not support websocket')
    }
 
    //连接发生错误的回调方法
    websocket.onerror = function(){
        setMessageInnerHTML("error");
    };
 
    //连接成功建立的回调方法
    websocket.onopen = function(event){
        setMessageInnerHTML("open");
    }
 
    //接收到消息的回调方法
    websocket.onmessage = function(event){
        setMessageInnerHTML(event.data);
    }
 
    //连接关闭的回调方法
    websocket.onclose = function(){
        setMessageInnerHTML("close");
    }
 
    //监听窗口关闭事件,当窗口关闭时,主动去关闭websocket连接,防止连接还没断开就关闭窗口,server端会抛异常。
    window.onbeforeunload = function(){
        websocket.close();
    }
 
    //将消息显示在网页上
    function setMessageInnerHTML(innerHTML){
        document.getElementById('message').innerHTML += innerHTML + '<br/>';
    }
 
    //关闭连接
    function closeWebSocket(){
        websocket.close();
    }
 
    //发送消息
    function send(){
        var message = document.getElementById('text').value;
        websocket.send(message);
    }
</script>
</html>