spring boot websocket
WebSocket原生注解
- pom.xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
- 配置类WebSocketConfig,这里开启了配置之后springboot才会去扫描对应的注解
@Configuration
@EnableWebSocket
public class WebSocketConfig {
@Bean
public ServerEndpointExporter serverEndpoint() {
return new ServerEndpointExporter();
}
}
- 处理消息类Msg
package com.osvue.env.app.msg;
import java.io.IOException;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;
import org.springframework.stereotype.Component;
/**
*
* @ClassName: Msg
* @Description:TODO(web socket 原生注解)
* @author: hzq
* @date: 2023-4-25 11:08:30
* @Copyright: 2023
*/
@ServerEndpoint("/skt")
@Component
public class Msg {
/**
* 连接成功
* @param session
*/
@OnOpen
public void onOpen(Session session) {
System.out.println("连接成功");
}
/**
* 连接关闭
* @param session
*/
@OnClose
public void onClose(Session session) {
System.out.println("连接关闭");
}
/**
* 接收到消息
* @param text
*/
@OnMessage
public String onMsg(String text) throws IOException {
return "servet 发送:" + text;
}
}
后面如果要对客户端发送消息的话使用
session.getBasicRemote().sendText(XXX)
@ServerEndpoint
- 通过这个
spring boot
就可以知道你暴露出去的 ws 应用的路径,有点类似我们经常用的@RequestMapping
。 - 比如你的启动端口是8080,而这个注解的值是ws,那我们就可以通过
ws://127.0.0.1:8080/ws
来连接你的应用。
- @OnOpen
当 websocket 建立连接成功后会触发这个注解修饰的方法。
@OnClose
当 websocket 建立的连接断开后会触发这个注解修饰的方法。
@OnMessage
- 当客户端发送消息到服务端时,会触发这个注解修改的方法,它有一个 String 入参表明客户端传入的值。
- @OnError
- 当 websocket 建立连接时出现异常会触发这个注解修饰的方法。
Spring封装的WebSocket
WebSocketConfig 配置 HttpAuthHandler myInterceptor HttpAuthHandler 定义连接开始 和关闭 myInterceptor 连接前 增加拦截 WsSessionManager
spring同样也为我们提供了WebSocket的封装,这种方式可以自己配置拦截器
在tcp握手之前对请求进行一次处理,可以避免一些恶意的连接。
配置类WebSocketConfig,通过实现
WebSocketConfigurer
类并覆盖相应的方法进行 websocket 的配置。我们主要覆盖 registerWebSocketHandlers 这个方法。通过向 WebSocketHandlerRegistry 设置不同参数来进行配置。
其中 addHandler方法添加我们的 ws 的 handler 处理类,第二个参数是你暴露出的 ws 路径。
addInterceptors添加我们写的拦截器。setAllowedOrigins这个是关闭跨域校验。
WebSocketConfig
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
/**
*
* @ClassName: WebSocketConfig
* @Description:TODO(spring boot config)
* @author: hzq
* @date: 2023-4-3 15:57:26
* @Copyright: 2023
*/
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
@Autowired
private HttpAuthHandler httpAuthHandler;
@Autowired
private MyInterceptor myInterceptor;
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry
.addHandler(httpAuthHandler, "itg")
.addInterceptors(myInterceptor)
.setAllowedOrigins("*");
}
}
- 处理器和拦截器
HttpAuthHandler用于处理ws的消息,通过继承
TextWebSocketHandler
类并覆盖相应方法,可以对 websocket 的事件进行处理,这里可以同原生注解的那几个注解连起来看afterConnectionEstablished 方法是在 socket 连接成功后被触发,同原生注解里的 @OnOpen 功能
afterConnectionClosed方法是在 socket 连接关闭后被触发,同原生注解里的 @OnClose 功能
handleTextMessage方法是在客户端发送信息时触发,同原生注解里的 @OnMessage 功能
/**
*
* @ClassName: HttpAuthHandler
* @Description:TODO(这里用一句话描述这个类的作用)
* @author: hzq
* @date: 2023-4-14 17:25:19
* @Copyright: 2023
*/
@Component
public class HttpAuthHandler extends TextWebSocketHandler {
/**
* socket 建立成功事件
* @param session
* @throws Exception
*/
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
Object token = session.getAttributes().get("token");
if (token != null) {
// 用户连接成功,放入在线用户缓存
WsSessionManager.add(token.toString(), session);
} else {
throw new RuntimeException("用户登录已经失效!");
}
}
/**
* 接收消息事件
* @param session
* @param message
* @throws Exception
*/
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
// 获得客户端传来的消息
String payload = message.getPayload();
Object token = session.getAttributes().get("token");
System.out.println("server 接收到 " + token + " 发送的 " + payload);
session.sendMessage(new TextMessage("server 发送给 " + token + " 消息 " + payload + " " + LocalDateTime.now().toString()));
}
/**
* socket 断开连接时
* @param session
* @param status
* @throws Exception
*/
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
Object token = session.getAttributes().get("token");
if (token != null) {
// 用户退出,移除缓存
WsSessionManager.remove(token.toString());
}
}
}
- MyInterceptor用来拦截ws请求
- 通过实现 HandshakeInterceptor 接口来定义握手拦截器,注意这里与上面 Handler 的事件是不同的
- 这里是建立握手时的事件,分为握手前与握手后,而 Handler 的事件是在握手成功后的基础上建立 socket 的连接。
- 所以在如果把认证放在这个步骤相对来说最节省服务器资源。它主要有两个方法 beforeHandshake 与 afterHandshake,
- 顾名思义一个在握手前触发,一个在握手后触发
package com.osvue.env.app.msg;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
/**
*
* @ClassName: MyInterceptor
* @Description:TODO(拦截器 --》 http Auth handler)
* @author: hzq
* @date: 2023-4-14 17:27:04
* @Copyright: 2023
*/
@Component
public class MyInterceptor implements HandshakeInterceptor {
/**
* 握手前
* @param request
* @param response
* @param wsHandler
* @param attributes
* @return
* @throws Exception
*/
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
System.out.println("握手开始");
// 获得请求参数
// HashMap<String, String> paramMap = HttpUtil.decodeParamMap(request.getURI().getQuery(), "utf-8");
// String uid = paramMap.get("token");
String uid = "paramMap.";
System.out.println(request.getURI().getQuery());
if (StringUtils.isNotBlank(uid)) {
// 放入属性域
attributes.put("token", uid);
System.out.println("用户 token " + uid + " 握手成功!");
return true;
}
System.out.println("用户登录已失效");
return false;
}
/**
* 握手后
* @param request
* @param response
* @param wsHandler
* @param exception
*/
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
System.out.println("握手完成");
}
}
- WsSessionManager
- 简单通过 ConcurrentHashMap来实现了一个 session 池,用来保存已经登录的WebSocket 的 session。
- 服务端发送消息给客户端必须要通过这个 session。
package com.osvue.env.app.msg;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import org.springframework.web.socket.WebSocketSession;
import lombok.extern.slf4j.Slf4j;
/**
*
* @ClassName: WsSessionManager
* @Description:TODO(这里用一句话描述这个类的作用)
* @author: hzq
* @date: 2023-4-14 17:27:50
* @Copyright: 2023
*/
@Slf4j
public class WsSessionManager {
/**
* 保存连接 session 的地方
*/
private static ConcurrentHashMap<String, WebSocketSession> SESSION_POOL = new ConcurrentHashMap<>();
/**
* 添加 session
* @param key
*/
public static void add(String key, WebSocketSession session) {
// 添加 session
SESSION_POOL.put(key, session);
}
/**
* 删除 session,会返回删除的 session
* @param key
* @return
*/
public static WebSocketSession remove(String key) {
// 删除 session
return SESSION_POOL.remove(key);
}
/**
* 删除并同步关闭连接
* @param key
*/
public static void removeAndClose(String key) {
WebSocketSession session = remove(key);
if (session != null) {
try {
// 关闭连接
session.close();
} catch (IOException e) {
// todo: 关闭出现异常处理
e.printStackTrace();
}
}
}
/**
* 获得 session
* @param key
* @return
*/
public static WebSocketSession get(String key) {
// 获得 session
return SESSION_POOL.get(key);
}
}
个人实例
- spring boot websocket
package com.osvue.env.app.ws;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
/**
*
* @ClassName: MsgHandle
* @Description:TODO(处理消息)
* @author: hzq
* @date: 2023-4-25 17:04:46
* @Copyright: 2023
*/
@ServerEndpoint("/itg/{userId}")
@Component
public class MsgHandle {
private static final Logger log = LoggerFactory.getLogger(MsgHandle.class);
/**
* 当前在线连接数
*/
private static AtomicInteger onlineCount = new AtomicInteger(0);
/**
* 用来存放每个客户端对应的 WebSocketServer 对象
*/
private static ConcurrentHashMap<String, MsgHandle> webSocketMap = new ConcurrentHashMap<>();
/**
* 与某个客户端的连接会话,需要通过它来给客户端发送数据
*/
private Session session;
/**
* 接收 userId
*/
private String userId = "";
/**
* 连接建立成功调用的方法
*/
@OnOpen
public void onOpen(Session session, @PathParam("userId") String userId) {
this.session = session;
this.userId = userId;
if (webSocketMap.containsKey(userId)) {
webSocketMap.remove(userId);
webSocketMap.put(userId, this);
} else {
webSocketMap.put(userId, this);
addOnlineCount();
}
log.info("用户连接:" + userId + ",当前在线人数为:" + getOnlineCount());
try {
sendMessage("连接成功!");
} catch (IOException e) {
log.error("用户:" + userId + ",网络异常!!!!!!");
}
}
/**
* 连接关闭调用的方法
*/
@OnClose
public void onClose() {
if (webSocketMap.containsKey(userId)) {
webSocketMap.remove(userId);
subOnlineCount();
}
log.info("用户退出:" + userId + ",当前在线人数为:" + getOnlineCount());
}
/**
* 收到客户端消息后调用的方法
*
* @param message 客户端发送过来的消息
*/
@OnMessage
public void onMessage(String message, Session session) {
log.info("用户消息:" + userId + ",报文:" + message);
if (!StringUtils.isEmpty(message)) {
try {
/*
JSONObject jsonObject = JSON.parseObject(message);
jsonObject.put("fromUserId", this.userId);
String toUserId = jsonObject.getString("toUserId");
if (!StringUtils.isEmpty(toUserId) && webSocketMap.containsKey(toUserId)) {
webSocketMap.get(toUserId).sendMessage(jsonObject.toJSONString());
} else {
log.error("请求的 userId:" + toUserId + "不在该服务器上");
}
*/
} catch (Exception e) {
e.printStackTrace();
}
}
}
/**
* 发生错误时调用
*
* @param session
* @param error
*/
@OnError
public void onError(Session session, Throwable error) {
log.error("用户错误:" + this.userId + ",原因:" + error.getMessage());
error.printStackTrace();
}
/**
* 实现服务器主动推送
*/
public void sendMessage(String message) throws IOException {
this.session.getBasicRemote().sendText(message);
}
public static synchronized AtomicInteger getOnlineCount() {
return onlineCount;
}
public static synchronized void addOnlineCount() {
MsgHandle.onlineCount.getAndIncrement();
}
public static synchronized void subOnlineCount() {
MsgHandle.onlineCount.getAndDecrement();
}
public static ConcurrentHashMap<String, MsgHandle> getWebSocketMap() {
return webSocketMap;
}
public static void setWebSocketMap(ConcurrentHashMap<String, MsgHandle> webSocketMap) {
MsgHandle.webSocketMap = webSocketMap;
}
}
@Configuration
public class WebSocketConfig {
@Bean
public ServerEndpointExporter serverEndpointExporter() {
return new ServerEndpointExporter();
}
}
@RestController
@RequestMapping("/ims")
public class MsgController {
@Resource
MsgHandle mh;
@GetMapping("msg/{msg}")
public Map<String, Object> send(@PathVariable String msg) {
Map<String, Object> mp = new HashMap<>();
String amsg = " 定时发送 " + new Date().toLocaleString() + msg;
try {
mh.getWebSocketMap().get("abccd").sendMessage("你好 世界");
} catch (IOException e) {
e.printStackTrace();
}
return mp;
}
}
- 简单前台
<!DOCTYPE HTML>
<html>
<head>
<title>My WebSocket</title>
</head>
<body>
Welcome<br/>
<input id="text" type="text" /><button onclick="send()">Send</button> <button onclick="closeWebSocket()">Close</button>
<div id="message">
</div>
</body>
<script type="text/javascript">
var websocket = null;
//判断当前浏览器是否支持WebSocket ,主要此处要更换为自己的地址
if('WebSocket' in window){
websocket = new WebSocket("ws://localhost:8005/itg/abccd");
}
else{
alert('Not support websocket')
}
//连接发生错误的回调方法
websocket.onerror = function(){
setMessageInnerHTML("error");
};
//连接成功建立的回调方法
websocket.onopen = function(event){
setMessageInnerHTML("open");
}
//接收到消息的回调方法
websocket.onmessage = function(event){
setMessageInnerHTML(event.data);
}
//连接关闭的回调方法
websocket.onclose = function(){
setMessageInnerHTML("close");
}
//监听窗口关闭事件,当窗口关闭时,主动去关闭websocket连接,防止连接还没断开就关闭窗口,server端会抛异常。
window.onbeforeunload = function(){
websocket.close();
}
//将消息显示在网页上
function setMessageInnerHTML(innerHTML){
document.getElementById('message').innerHTML += innerHTML + '<br/>';
}
//关闭连接
function closeWebSocket(){
websocket.close();
}
//发送消息
function send(){
var message = document.getElementById('text').value;
websocket.send(message);
}
</script>
</html>