package com.emonster.taroaichat.config;

import static com.emonster.taroaichat.security.SecurityUtils.JWT_ALGORITHM;

import com.emonster.taroaichat.management.SecurityMetersService;
import com.nimbusds.jose.jwk.source.ImmutableSecret;
import com.nimbusds.jose.util.Base64;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.jwt.NimbusJwtEncoder;
import org.springframework.security.oauth2.server.resource.web.BearerTokenResolver;
import org.springframework.security.oauth2.server.resource.web.DefaultBearerTokenResolver;

@Configuration
public class SecurityJwtConfiguration {

    private static final Logger LOG = LoggerFactory.getLogger(SecurityJwtConfiguration.class);

    /** Base64-encoded HMAC secret used to sign and verify access tokens. */
    @Value("${application.security.authentication.jwt.base64-secret}")
    private String jwtKey;

    /** Base64-encoded HMAC secret used to sign and verify refresh tokens (read from its own property). */
    @Value("${application.security.authentication.jwt.refresh-token.base64-secret}")
    private String jwtRefreshTokenKey;

    /**
     * Builds the primary {@link JwtDecoder}, used for access tokens.
     * <p>
     * Every decoding failure is classified by its exception message and counted via
     * {@link SecurityMetersService}. The original exception is always rethrown unchanged.
     *
     * @param metersService records invalid-signature, expired and malformed token metrics
     * @return a decoder verifying tokens with the access-token secret and {@code JWT_ALGORITHM}
     * @throws IllegalStateException if the access-token secret is not configured
     */
    @Bean("accessTokenJwtDecoder")
    @Primary
    public JwtDecoder jwtDecoder(SecurityMetersService metersService) {
        NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withSecretKey(getAccessSecretKey()).macAlgorithm(JWT_ALGORITHM).build();
        return token -> {
            try {
                return jwtDecoder.decode(token);
            } catch (Exception e) {
                trackDecodingFailure(metersService, e);
                throw e;
            }
        };
    }

    /**
     * Builds the primary {@link JwtEncoder}, which signs access tokens with the access-token secret.
     *
     * @return an encoder backed by an immutable HMAC secret
     * @throws IllegalStateException if the access-token secret is not configured
     */
    @Bean("accessTokenJwtEncoder")
    @Primary
    public JwtEncoder jwtEncoder() {
        return new NimbusJwtEncoder(new ImmutableSecret<>(getAccessSecretKey()));
    }

    /**
     * Builds the {@link JwtEncoder} that signs refresh tokens with the refresh-token secret.
     *
     * @return an encoder backed by an immutable HMAC secret
     * @throws IllegalStateException if the refresh-token secret is not configured
     */
    @Bean("refreshTokenJwtEncoder")
    public JwtEncoder refreshTokenJwtEncoder() {
        if (isBlank(jwtRefreshTokenKey)) {
            LOG.error("Refresh token secret key ('application.security.authentication.jwt.refresh-token.base64-secret') is not configured!");
            throw new IllegalStateException("Refresh token secret key is not configured.");
        }
        LOG.info("Configuring refreshTokenJwtEncoder with a dedicated secret.");
        return new NimbusJwtEncoder(new ImmutableSecret<>(getRefreshSecretKey()));
    }

    /**
     * Builds the {@link JwtDecoder} that verifies refresh tokens with the refresh-token secret.
     * <p>
     * Decoding failures are logged and rethrown unchanged. No metrics are recorded here.
     *
     * @param metersService injected but currently unused by this decoder
     * @return a decoder verifying tokens with the refresh-token secret and {@code JWT_ALGORITHM}
     * @throws IllegalStateException if the refresh-token secret is not configured
     */
    @Bean("refreshTokenJwtDecoder")
    public JwtDecoder refreshTokenJwtDecoder(SecurityMetersService metersService) {
        if (isBlank(jwtRefreshTokenKey)) {
            LOG.error("Refresh token secret key ('application.security.authentication.jwt.refresh-token.base64-secret') is not configured for the decoder!");
            throw new IllegalStateException("Refresh token secret key is not configured.");
        }
        LOG.info("Configuring refreshTokenJwtDecoder with a dedicated secret.");
        NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withSecretKey(getRefreshSecretKey()).macAlgorithm(JWT_ALGORITHM).build();
        return token -> {
            try {
                return jwtDecoder.decode(token);
            } catch (Exception e) {
                // A rejected refresh token (e.g. expired or tampered) is a client-side condition,
                // not a server fault. Log at WARN so routine expiries do not trigger error alerts.
                LOG.warn("Error decoding refresh token: {}", e.getMessage());
                throw e;
            }
        };
    }

    /**
     * Builds the resolver that extracts bearer tokens from incoming requests.
     * <p>
     * In addition to the default header-based resolution, tokens passed as a URI query
     * parameter are accepted.
     * NOTE(review): tokens in URLs can leak into access logs, proxies and browser history.
     * This is presumably needed for clients that cannot set headers (e.g. WebSocket/SSE). Confirm
     * before keeping it enabled globally.
     *
     * @return a {@link DefaultBearerTokenResolver} with URI query parameter support enabled
     */
    @Bean
    public BearerTokenResolver bearerTokenResolver() {
        var bearerTokenResolver = new DefaultBearerTokenResolver();
        bearerTokenResolver.setAllowUriQueryParameter(true);
        return bearerTokenResolver;
    }

    /**
     * Maps an access-token decoding failure to the matching security metric.
     * Unrecognised failures are logged with their stack trace.
     */
    private static void trackDecodingFailure(SecurityMetersService metersService, Exception e) {
        // getMessage() may legitimately be null. Normalising it prevents a NullPointerException
        // here from masking the original decoding exception that the caller rethrows.
        String message = e.getMessage() == null ? "" : e.getMessage();
        if (message.contains("Invalid signature")) {
            metersService.trackTokenInvalidSignature();
        } else if (message.contains("Jwt expired at")) {
            metersService.trackTokenExpired();
        } else if (
            message.contains("Invalid JWT serialization") ||
            message.contains("Malformed token") ||
            message.contains("Invalid unsecured/JWS/JWE")
        ) {
            metersService.trackTokenMalformed();
        } else {
            // Trailing throwable argument makes SLF4J log the stack trace for unexpected failures.
            LOG.error("Unknown JWT error {}", e.getMessage(), e);
        }
    }

    /** Returns the access-token HMAC key, failing fast with a clear message if it is not configured. */
    private SecretKey getAccessSecretKey() {
        if (isBlank(jwtKey)) {
            LOG.error("Attempted to get access secret key, but 'application.security.authentication.jwt.base64-secret' is not configured!");
            throw new IllegalStateException("Access token secret key is not configured. Cannot create access key.");
        }
        return toSecretKey(jwtKey);
    }

    /** Returns the refresh-token HMAC key, failing fast with a clear message if it is not configured. */
    private SecretKey getRefreshSecretKey() {
        if (isBlank(jwtRefreshTokenKey)) {
            LOG.error("Attempted to get refresh secret key, but 'application.security.authentication.jwt.refresh-token.base64-secret' is not configured!");
            throw new IllegalStateException("Refresh token secret key is not configured. Cannot create refresh key.");
        }
        return toSecretKey(jwtRefreshTokenKey);
    }

    /** Decodes a Base64 secret into a {@link SecretKey} tagged with the {@code JWT_ALGORITHM} name. */
    private static SecretKey toSecretKey(String base64Secret) {
        byte[] keyBytes = Base64.from(base64Secret).decode();
        return new SecretKeySpec(keyBytes, 0, keyBytes.length, JWT_ALGORITHM.getName());
    }

    /** True when the value is null, empty, or only whitespace/control characters (per {@link String#trim()}). */
    private static boolean isBlank(String value) {
        return value == null || value.trim().isEmpty();
    }
}
