package com.emonster.taroaichat.service.sse;

import com.emonster.taroaichat.service.dto.sse.SseEvent;
import com.emonster.taroaichat.service.dto.sse.SseEventType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Service for managing Server-Sent Events connections.
 *
 * <p>Holds at most one {@link SseEmitter} per session id. Every path that unregisters an emitter
 * uses {@link Map#remove(Object, Object)}, so a callback or send failure on an emitter that has
 * already been superseded (e.g. after a client reconnect) can never evict its replacement.
 */
@Service
public class SseService {

    private static final Logger log = LoggerFactory.getLogger(SseService.class);

    private final Map<String, SseEmitter> emitters = new ConcurrentHashMap<>();

    /**
     * Create a new SSE connection for a session and register it as the session's current emitter.
     *
     * <p>The emitter is created with a {@code Long.MAX_VALUE} timeout; it is unregistered when it
     * completes, times out or errors.
     *
     * @param sessionId the session the connection belongs to
     * @return the newly registered emitter
     */
    public SseEmitter createEmitter(String sessionId) {
        SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);

        // Conditional removal: only unregister if this exact emitter is still the one mapped to the
        // session. A reconnect may already have installed a replacement that must stay registered.
        emitter.onCompletion(() -> {
            log.debug("SSE connection completed for session: {}", sessionId);
            emitters.remove(sessionId, emitter);
        });

        emitter.onTimeout(() -> {
            log.debug("SSE connection timed out for session: {}", sessionId);
            emitters.remove(sessionId, emitter);
        });

        emitter.onError(e -> {
            log.warn("SSE connection error for session: {}", sessionId, e);
            emitters.remove(sessionId, emitter);
        });

        SseEmitter previous = emitters.put(sessionId, emitter);
        if (previous != null) {
            // NOTE(review): the superseded emitter is left open (not completed), as before; it is
            // no longer reachable through this service and receives no further events or heartbeats.
            log.debug("Replaced existing SSE connection for session: {}", sessionId);
        }
        log.debug("Created SSE connection for session: {}", sessionId);

        return emitter;
    }

    /**
     * Send an event to a specific session.
     *
     * <p>Delivery failures are handled here and not propagated. On an I/O error the emitter is
     * unregistered and completed with that error. If the emitter has already completed (its
     * completion callback has not unregistered it yet), it is simply unregistered. If the session
     * has no registered emitter, a warning is logged.
     *
     * @param sessionId the target session
     * @param event     the event to deliver; its type, lower-cased, becomes the SSE event name
     */
    public void sendEvent(String sessionId, SseEvent event) {
        SseEmitter emitter = emitters.get(sessionId);
        if (emitter == null) {
            log.warn("No active SSE connection for session: {}", sessionId);
            return;
        }

        try {
            SseEmitter.SseEventBuilder builder = SseEmitter.event()
                .id(resolveEventId(event))
                // Locale.ROOT: default-locale lowercasing is not stable (e.g. Turkish dotless i).
                .name(event.getType().toString().toLowerCase(Locale.ROOT))
                .data(event);

            emitter.send(builder);
            log.trace("Sent SSE event {} to session: {}", event.getType(), sessionId);
        } catch (IOException e) {
            log.error("Failed to send SSE event to session: {}", sessionId, e);
            emitters.remove(sessionId, emitter);
            emitter.completeWithError(e);
        } catch (IllegalStateException e) {
            // SseEmitter.send rejects emitters that have already completed; this races with the
            // onCompletion callback, so drop the stale registration instead of failing the caller.
            log.debug("SSE connection already completed for session: {}", sessionId);
            emitters.remove(sessionId, emitter);
        }
    }

    /**
     * Compute the SSE event id for an event.
     *
     * <p>If the event data is a map with a non-null {@code _messageId} entry, the id is
     * {@code "msg-" + messageId}, so a client's resumption id maps back to the message. Otherwise,
     * it is the event type followed by the current time in milliseconds.
     */
    private static String resolveEventId(SseEvent event) {
        Object data = event.getData();
        if (data instanceof Map) {
            Object messageId = ((Map<?, ?>) data).get("_messageId");
            if (messageId != null) {
                return "msg-" + messageId;
            }
        }
        return event.getType() + "-" + System.currentTimeMillis();
    }

    /**
     * Send an error event to a specific session.
     *
     * @param sessionId the target session
     * @param error     the error message; may be {@code null}
     * @param code      the error code; may be {@code null}
     */
    public void sendError(String sessionId, String error, String code) {
        // Map.of would throw NullPointerException on a null value, so reporting an error could fail.
        // LinkedHashMap tolerates nulls and keeps a stable key order.
        Map<String, String> errorData = new LinkedHashMap<>();
        errorData.put("error", error);
        errorData.put("code", code);
        sendEvent(sessionId, new SseEvent(sessionId, SseEventType.ERROR, errorData));
    }

    /**
     * Close a session's SSE connection: unregister its emitter and complete it.
     * Does nothing if the session has no registered emitter.
     *
     * @param sessionId the session whose connection should be closed
     */
    public void closeConnection(String sessionId) {
        SseEmitter emitter = emitters.remove(sessionId);
        if (emitter != null) {
            emitter.complete();
            log.debug("Closed SSE connection for session: {}", sessionId);
        }
    }

    /**
     * Check if a session has an active connection.
     *
     * @param sessionId the session to check
     * @return {@code true} if an emitter is currently registered for the session
     */
    public boolean hasActiveConnection(String sessionId) {
        return emitters.containsKey(sessionId);
    }

    /**
     * Get the number of active connections.
     *
     * @return the number of currently registered emitters
     */
    public int getActiveConnectionCount() {
        return emitters.size();
    }

    /**
     * Send heartbeat to all active connections to keep them alive.
     * This prevents proxy timeouts and connection drops.
     *
     * <p>An emitter that fails to accept the heartbeat is unregistered and completed with the
     * error. A failure on one session never stops heartbeats to the others.
     */
    @Scheduled(fixedDelay = 60000) // Every 60 seconds
    public void sendHeartbeat() {
        if (emitters.isEmpty()) {
            return;
        }

        log.debug("Sending heartbeat to {} active SSE connections", emitters.size());

        emitters.forEach((sessionId, emitter) -> {
            try {
                // Send a comment (starts with :) which is ignored by EventSource but keeps connection alive
                emitter.send(SseEmitter.event().comment("heartbeat"));
            } catch (IOException | IllegalStateException e) {
                // IllegalStateException: the emitter completed but is not unregistered yet. It must be
                // caught here, otherwise it escapes forEach and skips all remaining sessions.
                log.debug("Failed to send heartbeat to session {}, removing stale connection", sessionId);
                emitters.remove(sessionId, emitter);
                try {
                    emitter.completeWithError(e);
                } catch (Exception ignored) {
                    // Best effort: the connection is already broken, so completing it may fail as well.
                }
            }
        });
    }
}
