package scoring

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"

	"github.com/google/uuid"
)

// Service computes candidate results for test sessions.
type Service interface {
	// Calculate scores every answer recorded for sessionID, persists the
	// resulting CandidateResult, and returns it.
	Calculate(ctx context.Context, sessionID uuid.UUID) (*CandidateResult, error)
}

// service is the Service implementation returned by NewService. All data it
// reads (answers, scoring rules) and writes (results) goes through repo.
type service struct {
	repo Repository // source of scoring data and rules; sink for results
}

// NewService builds a Service that loads answers and scoring rules from repo
// and saves the computed results back through it.
func NewService(repo Repository) Service {
	svc := new(service)
	svc.repo = repo
	return svc
}

// Calculate scores every answer recorded for sessionID, persists the
// resulting CandidateResult through the repository, and returns it.
//
// The scoring algorithm is selected by the "test_code" field of the scoring
// rule for the session's test type. The code is compared after trimming
// surrounding whitespace and upper-casing, so "cfit", " CFIT " and "CFIT" all
// select CFIT scoring. When the rule has no test code (or its JSON cannot be
// decoded), the question type of the first answer row decides instead.
// Unrecognised modes are scored as generic multiple choice.
//
// An error is returned when the scoring data or rule cannot be loaded, when
// the session has no answers, when the selected scorer fails, or when the
// result cannot be saved.
func (s *service) Calculate(ctx context.Context, sessionID uuid.UUID) (*CandidateResult, error) {
	// 1. Fetch all scoring rows for the session in a single repository call.
	data, err := s.repo.GetScoringData(ctx, sessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get scoring data: %w", err)
	}
	if len(data) == 0 {
		return nil, errors.New("no answers found for session")
	}

	// Every row of a session is assumed to share the same candidate and test
	// type, so the first row is taken as representative.
	candidateID := data[0].CandidateID
	testTypeID := data[0].TestTypeID

	// 2. Fetch the scoring rule for this test type.
	// NOTE(review): assumes GetScoringRule never returns a nil rule together
	// with a nil error — confirm against the repository implementation.
	rule, err := s.repo.GetScoringRule(ctx, testTypeID)
	if err != nil {
		return nil, fmt.Errorf("failed to get scoring rules: %w", err)
	}

	var ruleConfig struct {
		TestCode string `json:"test_code"`
	}
	// Best effort: empty or malformed rules leave TestCode blank, which
	// triggers the question-type fallback below.
	_ = json.Unmarshal(rule.Rules, &ruleConfig)

	// Normalise so that differently-cased codes do not silently fall through
	// to multiple-choice scoring.
	testMode := strings.ToUpper(strings.TrimSpace(ruleConfig.TestCode))
	if testMode == "" {
		switch data[0].QuestionType {
		case "multiple_choice":
			testMode = "MULTIPLE_CHOICE"
		case "most_least_choice":
			testMode = "DISC"
		case "likert_scale":
			testMode = "MBTI"
		case "pattern_matrix":
			testMode = "CFIT"
		}
	}

	// 3. Route to the scorer for the resolved mode. Question-type names are
	// accepted as aliases for their test codes.
	var result *CandidateResult
	switch testMode {
	case "CFIT", "PATTERN_MATRIX":
		result, err = s.calculateCFIT(sessionID, candidateID, testTypeID, data, rule.Rules)
	case "DISC", "MOST_LEAST_CHOICE":
		result, err = s.calculateDISC(sessionID, candidateID, testTypeID, data)
	case "MBTI", "LIKERT_SCALE":
		result, err = s.calculateMBTI(sessionID, candidateID, testTypeID, data)
	default:
		result, err = s.calculateMultipleChoice(sessionID, candidateID, testTypeID, data)
	}
	if err != nil {
		return nil, err
	}

	// 4. Persist the computed result.
	if err := s.repo.SaveResult(ctx, result); err != nil {
		return nil, fmt.Errorf("failed to save result: %w", err)
	}

	return result, nil
}

// calculateMultipleChoice scores data as plain single-answer questions.
//
// An answer is correct when either its "selected_index" equals the key's
// "correct_index" or its "selected" text equals the key's "correct_answer".
// A correct answer earns the row's ScoringWeight. Rows whose answer or key
// JSON cannot be decoded are scored as incorrect rather than aborting the
// whole session.
//
// TotalScore is the sum of earned weights. The Result JSON holds a
// per-question "breakdown" list of question_id / is_correct / score entries.
// An error is returned only if that JSON cannot be encoded.
func (s *service) calculateMultipleChoice(sessionID, candidateID, testTypeID uuid.UUID, data []ScoringDataRow) (*CandidateResult, error) {
	var totalScore float64
	// Non-nil and pre-sized: an empty breakdown encodes as [] rather than null.
	breakdowns := make([]map[string]interface{}, 0, len(data))

	for _, row := range data {
		// Decode failures leave the maps nil. Lookups on a nil map yield nil,
		// so such rows simply score as incorrect.
		var ans, key map[string]interface{}
		_ = json.Unmarshal(row.Answer, &ans)
		_ = json.Unmarshal(row.AnswerKey, &key)

		isCorrect := false
		// JSON numbers decode to float64 inside map[string]interface{}.
		if sel, ok := ans["selected_index"].(float64); ok {
			if want, ok := key["correct_index"].(float64); ok && sel == want {
				isCorrect = true
			}
		}
		if sel, ok := ans["selected"].(string); ok {
			if want, ok := key["correct_answer"].(string); ok && sel == want {
				isCorrect = true
			}
		}

		score := 0.0
		if isCorrect {
			score = row.ScoringWeight
			totalScore += score
		}

		breakdowns = append(breakdowns, map[string]interface{}{
			"question_id": row.QuestionID,
			"is_correct":  isCorrect,
			"score":       score,
		})
	}

	resultJSON, err := json.Marshal(map[string]interface{}{
		"breakdown": breakdowns,
	})
	if err != nil {
		return nil, fmt.Errorf("encode multiple choice result: %w", err)
	}

	return &CandidateResult{
		SessionID:   sessionID,
		CandidateID: candidateID,
		TestTypeID:  testTypeID,
		TotalScore:  totalScore,
		Result:      resultJSON,
	}, nil
}

// calculateCFIT scores a CFIT (pattern matrix) session.
//
// Each row is judged by cfitAnswerCorrect. A correct row adds its
// ScoringWeight to the raw total and to sectionScores[row.OrderNumber]. The
// raw total, formatted with no decimals, is then looked up in the "mapping"
// object of rulesJSON to obtain a normalized score. When there is no
// matching entry, or the rules cannot be decoded, NormalizedScore is nil.
//
// The Result JSON holds "raw_score" and "section_scores". An error is
// returned only if that JSON cannot be encoded.
func (s *service) calculateCFIT(sessionID, candidateID, testTypeID uuid.UUID, data []ScoringDataRow, rulesJSON []byte) (*CandidateResult, error) {
	var totalRaw float64
	// NOTE(review): keyed by row.OrderNumber — confirm that this is the CFIT
	// section/subtest number and not a per-question position.
	sectionScores := make(map[int]float64)

	for _, row := range data {
		if !cfitAnswerCorrect(row.Answer, row.AnswerKey) {
			continue
		}
		totalRaw += row.ScoringWeight
		sectionScores[row.OrderNumber] += row.ScoringWeight
	}

	var rules struct {
		Mapping map[string]float64 `json:"mapping"`
	}
	// Best effort: missing or malformed rules just mean no normalization.
	_ = json.Unmarshal(rulesJSON, &rules)

	var normalized *float64
	if val, ok := rules.Mapping[fmt.Sprintf("%.0f", totalRaw)]; ok {
		normalized = &val
	}

	resultJSON, err := json.Marshal(map[string]interface{}{
		"raw_score":      totalRaw,
		"section_scores": sectionScores,
	})
	if err != nil {
		return nil, fmt.Errorf("encode CFIT result: %w", err)
	}

	return &CandidateResult{
		SessionID:       sessionID,
		CandidateID:     candidateID,
		TestTypeID:      testTypeID,
		TotalScore:      totalRaw,
		NormalizedScore: normalized,
		Result:          resultJSON,
	}, nil
}

// cfitAnswerCorrect reports whether a CFIT answer matches its key. Three
// formats are tried in order:
//
//  1. Multi-select: {"selected": [...]} against {"correct_answers": [...]}.
//     Used when both lists decode and are non-empty. The selections must
//     equal the key exactly, ignoring order.
//  2. Single choice: {"selected": "..."} against {"correct_answer": "..."}.
//     Used when both strings are non-empty.
//  3. Index: {"selected_index": n} against {"correct_index": n}. Matches only
//     when both fields decode as integers. Index 0 is therefore a valid
//     answer, while an absent, null or non-integer field never matches.
func cfitAnswerCorrect(answer, key []byte) bool {
	var multiAns struct {
		Selected []string `json:"selected"`
	}
	var multiKey struct {
		CorrectAnswers []string `json:"correct_answers"`
	}
	errAns := json.Unmarshal(answer, &multiAns)
	errKey := json.Unmarshal(key, &multiKey)
	if errAns == nil && errKey == nil && len(multiAns.Selected) > 0 && len(multiKey.CorrectAnswers) > 0 {
		return sameStringMultiset(multiAns.Selected, multiKey.CorrectAnswers)
	}

	var singleAns struct {
		Selected string `json:"selected"`
	}
	var singleKey struct {
		CorrectAnswer string `json:"correct_answer"`
	}
	// Decode errors (e.g. "selected" holding a list) leave the strings empty,
	// which moves on to the index format.
	_ = json.Unmarshal(answer, &singleAns)
	_ = json.Unmarshal(key, &singleKey)
	if singleAns.Selected != "" && singleKey.CorrectAnswer != "" {
		return singleAns.Selected == singleKey.CorrectAnswer
	}

	// Pointers distinguish "absent/null" from a genuine index 0. A decode
	// error must be checked: the decoder may already have allocated the
	// pointer before rejecting a non-integer value.
	var idxAns struct {
		SelectedIndex *int `json:"selected_index"`
	}
	var idxKey struct {
		CorrectIndex *int `json:"correct_index"`
	}
	if json.Unmarshal(answer, &idxAns) != nil || json.Unmarshal(key, &idxKey) != nil {
		return false
	}
	if idxAns.SelectedIndex == nil || idxKey.CorrectIndex == nil {
		return false
	}
	return *idxAns.SelectedIndex == *idxKey.CorrectIndex
}

// sameStringMultiset reports whether a and b contain the same strings with the
// same multiplicities, in any order.
func sameStringMultiset(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	counts := make(map[string]int, len(a))
	for _, v := range a {
		counts[v]++
	}
	for _, v := range b {
		counts[v]--
		if counts[v] < 0 {
			return false
		}
	}
	// Equal lengths and no negative count imply every count is zero.
	return true
}

// calculateMBTI derives a four-letter MBTI type from the session's answers.
//
// Each answer's "selected_item_id" is looked up among the items in the row's
// ItemsJSON. The matched item's meta.trait_code casts one vote for that trait.
// Trait codes are trimmed and upper-cased before counting. Codes other than
// E, I, S, N, T, F, J and P are ignored, so they cannot appear in the
// reported trait_scores. Answers without a selection, or that cannot be
// decoded, are skipped. A tie on any axis resolves to the first letter of the
// pair (E, S, T, J).
//
// TotalScore is always 0. The Result JSON holds "mbti_type" and
// "trait_scores". An error is returned only if that JSON cannot be encoded.
func (s *service) calculateMBTI(sessionID, candidateID, testTypeID uuid.UUID, data []ScoringDataRow) (*CandidateResult, error) {
	counts := map[string]int{"E": 0, "I": 0, "S": 0, "N": 0, "T": 0, "F": 0, "J": 0, "P": 0}

	for _, row := range data {
		var ans struct {
			SelectedItemID string `json:"selected_item_id"`
		}
		// A decode failure leaves SelectedItemID empty, i.e. a skipped question.
		_ = json.Unmarshal(row.Answer, &ans)
		if ans.SelectedItemID == "" {
			continue
		}

		var items []struct {
			ID   string `json:"id"`
			Meta struct {
				TraitCode string `json:"trait_code"`
			} `json:"meta"`
		}
		_ = json.Unmarshal(row.ItemsJSON, &items)

		for _, it := range items {
			if it.ID != ans.SelectedItemID {
				continue
			}
			trait := strings.ToUpper(strings.TrimSpace(it.Meta.TraitCode))
			if _, known := counts[trait]; known {
				counts[trait]++
				break
			}
		}
	}

	// Each axis picks the pole with more votes; ties favour the first letter.
	axes := [4][2]string{{"E", "I"}, {"S", "N"}, {"T", "F"}, {"J", "P"}}
	var b strings.Builder
	for _, ax := range axes {
		if counts[ax[0]] >= counts[ax[1]] {
			b.WriteString(ax[0])
		} else {
			b.WriteString(ax[1])
		}
	}
	mbti := b.String()

	resultJSON, err := json.Marshal(map[string]interface{}{
		"mbti_type":    mbti,
		"trait_scores": counts,
	})
	if err != nil {
		return nil, fmt.Errorf("encode MBTI result: %w", err)
	}

	return &CandidateResult{
		SessionID:   sessionID,
		CandidateID: candidateID,
		TestTypeID:  testTypeID,
		TotalScore:  0,
		Result:      resultJSON,
	}, nil
}

// calculateDISC tallies DISC "most"/"least" selections per trait.
//
// For every row:
//   - the item whose ID equals the answer's "most" choice adds one to the Most
//     count of its meta.plus_trait;
//   - the item matching the "least" choice adds one to the Least count of its
//     meta.minus_trait.
//
// Trait codes are trimmed and upper-cased, then matched against D, I, S, C and
// "*"; any other code is ignored. Empty selections never match an item, and
// rows whose answer JSON cannot be decoded contribute nothing. Change is Most
// minus Least for each trait.
//
// TotalScore is always 0. The Result JSON holds the per-trait counts under
// "disc_scores". An error is returned only if that JSON cannot be encoded.
func (s *service) calculateDISC(sessionID, candidateID, testTypeID uuid.UUID, data []ScoringDataRow) (*CandidateResult, error) {
	type discScores struct {
		Most   int `json:"most"`
		Least  int `json:"least"`
		Change int `json:"change"`
	}
	scores := map[string]*discScores{
		"D": {}, "I": {}, "S": {}, "C": {}, "*": {},
	}

	for _, row := range data {
		var ans struct {
			Most  string `json:"most"`
			Least string `json:"least"`
		}
		if err := json.Unmarshal(row.Answer, &ans); err != nil {
			continue
		}

		var items []struct {
			ID   string `json:"id"`
			Meta struct {
				PlusTrait  string `json:"plus_trait"`
				MinusTrait string `json:"minus_trait"`
			} `json:"meta"`
		}
		_ = json.Unmarshal(row.ItemsJSON, &items)

		for _, it := range items {
			// Guard against an empty selection matching an item with an empty ID.
			if ans.Most != "" && it.ID == ans.Most {
				if sc, ok := scores[strings.ToUpper(strings.TrimSpace(it.Meta.PlusTrait))]; ok {
					sc.Most++
				}
			}
			if ans.Least != "" && it.ID == ans.Least {
				if sc, ok := scores[strings.ToUpper(strings.TrimSpace(it.Meta.MinusTrait))]; ok {
					sc.Least++
				}
			}
		}
	}

	// Values are pointers, so updating through sc mutates the map entries.
	for _, sc := range scores {
		sc.Change = sc.Most - sc.Least
	}

	resultJSON, err := json.Marshal(map[string]interface{}{
		"disc_scores": scores,
	})
	if err != nil {
		return nil, fmt.Errorf("encode DISC result: %w", err)
	}

	return &CandidateResult{
		SessionID:   sessionID,
		CandidateID: candidateID,
		TestTypeID:  testTypeID,
		TotalScore:  0,
		Result:      resultJSON,
	}, nil
}
