package middleware

import (
	"context"
	"fmt"
	"time"

	"lune/talentscale/infra/cache"
	"lune/talentscale/internal/domain"
	"lune/talentscale/pkg/logger"
	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
)

func MonitoringMiddleware() fiber.Handler {
	return func(c *fiber.Ctx) error {
		start := time.Now()
		requestID := uuid.New().String()
		c.Set("X-Request-ID", requestID)
		c.Locals("request_id", requestID)

		// Process request
		err := c.Next()

		duration := time.Since(start).Milliseconds()
		statusCode := c.Response().StatusCode()
		
		// Determine log level
		level := "info"
		if statusCode >= 500 {
			level = "error"
		} else if statusCode >= 400 {
			level = "warning"
		}

		// Prepare log entry
		logEntry := &domain.SystemLog{
			Level:      level,
			Message:    fmt.Sprintf("%s %s -> %d", c.Method(), c.Path(), statusCode),
			Source:     "api",
			Action:     "http_request",
			IPAddress:  c.IP(),
			UserAgent:  c.Get("User-Agent"),
			StatusCode: statusCode,
			DurationMs: duration,
			RequestID:  requestID,
			Method:     c.Method(),
			Endpoint:   c.Path(),
			CreatedAt:  time.Now(),
		}

		// Extract user/company from locals
		if userIDRaw := c.Locals("user_id"); userIDRaw != nil {
			if id, err := uuid.Parse(fmt.Sprintf("%v", userIDRaw)); err == nil {
				logEntry.UserID = &id
			}
		}
		if companyIDRaw := c.Locals("company_id"); companyIDRaw != nil {
			if id, err := uuid.Parse(fmt.Sprintf("%v", companyIDRaw)); err == nil {
				logEntry.CompanyID = &id
			}
		}

		if err != nil {
			logEntry.ErrorStack = err.Error()
		}

		// 🎯 LOG SAMPLING (Optimization for high traffic)
		// Error/Warn: 100% log
		// Info: 20% log (sample 1 in 5)
		shouldLog := true
		if level == "info" {
			if time.Now().UnixNano()%5 != 0 {
				shouldLog = false
			}
		}

		if shouldLog {
			logger.Log(logEntry)
		}

		// Track metrics in Redis (Async - 100% frequency)
		go trackMetrics(c.Path(), duration, statusCode)

		return err
	}
}

func trackMetrics(path string, duration int64, statusCode int) {
	if cache.Client == nil {
		return
	}

	ctx := context.Background()
	pipe := cache.Client.Pipeline()

	now := time.Now().Format("2006-01-02:15:04") // Minute precision
	baseKey := fmt.Sprintf("metrics:%s", now)

	// Total requests
	pipe.Incr(ctx, fmt.Sprintf("%s:total", baseKey))
	pipe.Incr(ctx, fmt.Sprintf("%s:%s:total", baseKey, path))

	// Errors
	if statusCode >= 400 {
		pipe.Incr(ctx, fmt.Sprintf("%s:errors", baseKey))
	}

	// Response time (sum and count for avg)
	pipe.IncrBy(ctx, fmt.Sprintf("%s:duration_sum", baseKey), duration)
	pipe.Incr(ctx, fmt.Sprintf("%s:duration_count", baseKey))

	// Set expiry for metrics
	pipe.Expire(ctx, fmt.Sprintf("%s:total", baseKey), 24*time.Hour)
	// ... could add more expiries but pipeline is efficient

	_, _ = pipe.Exec(ctx)
}
