package com.onsiteservice.common.socket;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.onsiteservice.constant.constant.Constants;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.*;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * @author 潘维吉
 * @date 2019-04-09 11:05
 * 统一处理 WebSocket消息
 */
@Component
@Slf4j
public class SocketHandler implements WebSocketHandler {

    /** 在线用户,将其保存在set中,避免用户重复登录,出现多个session */
    private static final Map<String, WebSocketSession> USERS;
    private static final String SEND_ALL = "all";

    static {
        USERS = Collections.synchronizedMap(new HashMap<>());
    }

    /**
     * 给某个用户发送消息
     *
     * @param userId  用户id
     * @param message 消息
     */
    public void sendMessageToUser(String userId, TextMessage message) {
        WebSocketSession user = USERS.get(userId);
        try {
            if (user.isOpen()) {
                user.sendMessage(message); }
        } catch (Exception e) {
            log.warn("给userId=" + userId + "用户发送消息异常:" + e.getMessage());
        }
    }

    /**
     * 给某些用户发送消息
     *
     * @param userId  用户id
     * @param message 消息
     */
    public void sendMessageToSomeUser(TextMessage message, String... userId) {
        Arrays.asList(userId).forEach(item -> sendMessageToUser(item, message));
    }

    /**
     * 给所有在线用户发送消息
     *
     * @param message 文本消息
     */
    public void sendMessageToAll(TextMessage message) {
        WebSocketSession user = null;
        for (String key : USERS.keySet()) {
            user = USERS.get(key);
            try {
                if (user.isOpen()) {
                    user.sendMessage(message);
                }
            } catch (Exception e) {
                log.warn("给所有在线用户发送消息异常:" + e.getMessage());
            }
        }
    }

    /**
     * 成功连接WebSocket后执行
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) {
        try {
            log.info("WebSocket链接成功 \uD83D\uDE02 ");
            String userId = (String) session.getAttributes().get(Constants.WEBSOCKET_USER_ID);
            if (userId != null) {
                USERS.put(userId, session);
                // log.info("有新WebSocket连接加入,当前在线人数为:" + USERS.size());
                session.sendMessage(new SocketData().toTextMessage("open", "WebSocket连接成功"));
            }
        } catch (Exception e) {
            log.error("WebSocket连接 afterConnectionEstablished异常:" + e.getMessage());
        }
    }

    /**
     * 处理收到客户端消息
     *
     * @param message 客户端发送过来的消息
     */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        var payload = message.getPayload();
        if (payload instanceof String) {
            log.info("处理客户端发送的字符串消息:" + payload.toString());
        } else if (payload instanceof Object) {
            JSONObject msg = JSON.parseObject(payload.toString());
            log.info("处理客户端发送的对象消息:" + payload.toString());
            JSONObject obj = new JSONObject();
            String type = msg.get("type").toString();
            if (StringUtils.isNotBlank(type) && SEND_ALL.equals(type)) {
                //给所有人
                obj.put("msg", msg.getString("msg"));
                log.info("给所有人发消息");
                sendMessageToAll(new TextMessage(obj.toJSONString()));
            } else {
                //给个人
                String to = msg.getString("to");
                obj.put("msg", msg.getString("msg"));
                log.info("给个人发消息");
                sendMessageToUser(to, new TextMessage(obj.toJSONString()));
            }
        }
    }

    /**
     * 处理传输错误
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        if (session.isOpen()) {
            session.close();
        }
        log.info("链接出错，关闭链接,异常信息:" + exception.getMessage());
        String userId = getUserId(session);
        if (USERS.get(userId) != null) {
            USERS.remove(userId);
        }
    }

    /**
     * 在两端WebSocket connection都关闭或transport error发生后执行
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        log.info("链接关闭,关闭信息:" + closeStatus.toString());
        String userId = getUserId(session);
        if (USERS.get(userId) != null) {
            USERS.remove(userId);
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * 获取用户id
     */
    private String getUserId(WebSocketSession session) {
        try {
            return (String) session.getAttributes().get(Constants.WEBSOCKET_USER_ID);
        } catch (Exception e) {
            return null;
        }
    }

    /**
     * 根据当前用户id登出
     *
     * @param userId 用户id
     */
    public void logout(String userId) {
        USERS.remove(userId);
        log.info("用户登出,userId:" + userId);
    }
}
