from collections import defaultdict
from datetime import datetime
from typing import Any, Optional
from uuid import UUID

from fastapi import HTTPException, status
from sqlalchemy import func, cast, Integer, and_, or_
from sqlalchemy.orm import Session

from src.core.exceptions import APIException
from src.utils.pagination import QueryPaginator
from src.apps.wine.keyword_spread.models.keyword_spread import KeywordSpread
from src.apps.wine.keyword_spread.schemas.keyword_spread import KeywordSpreadCreateSchema,KeywordSpreadOutputSchema,KeywordSpreadSchema,KeywordSpreadUpdateSchema,KeywordSpreadFilterSchema
from src.utils.constants import API_PREFIXES
from src.core.config import settings


async def get_all_keyword_spreads(
    db: Session,
    page: int = 1,
    per_page: int = 10,
    payload: Optional[KeywordSpreadFilterSchema] = None,
) -> Any:
    """Retrieve non-deleted keyword spreads as a paginated result.

    Rows are ordered by region code, then keyword type, then newest
    ``date_created`` first.

    Args:
        db: SQLAlchemy session used to build the query.
        page: 1-based page number. Values below 1 are treated as page 1.
        per_page: Page size, passed to the paginator as ``limit``.
        payload: Optional filters. ``payload.region_codes`` is a
            comma-separated string; blank entries are ignored and the
            remaining codes restrict results to those regions.

    Returns:
        Whatever ``QueryPaginator.paginate()`` produces for the query.

    Raises:
        HTTPException / APIException: re-raised unchanged if raised while
            building or paginating the query.
        APIException: 500 for any other failure, chained to the cause.
    """
    try:
        # Clamp so a page <= 0 cannot produce a negative OFFSET, which the
        # database would reject (surfacing as an opaque 500).
        offset = max(page - 1, 0) * per_page
        query = db.query(KeywordSpread).filter(KeywordSpread.deleted_at.is_(None))

        raw_codes = getattr(payload, "region_codes", None)
        if raw_codes:
            region_codes = [rc.strip() for rc in raw_codes.split(",") if rc.strip()]
            # A string of only commas/whitespace yields no codes: no filter.
            if region_codes:
                query = query.filter(KeywordSpread.region_code.in_(region_codes))

        query = query.order_by(
            KeywordSpread.region_code,
            KeywordSpread.keyword_type,
            KeywordSpread.date_created.desc(),
        )
        paginator = QueryPaginator(
            query=query,
            schema=KeywordSpreadOutputSchema,
            url="".join([str(settings.api_base_url()), API_PREFIXES.KEYWORD_SPREAD]),
            offset=offset,
            limit=per_page,
            use_orm=True,
        )
        return paginator.paginate()
    except (HTTPException, APIException):
        # Deliberate HTTP errors keep their own status code instead of
        # being flattened into a generic 500.
        raise
    except Exception as e:
        raise APIException(
            module="get_all_keyword_spreads",
            error={"exception": str(e)},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            message="Error retrieving keyword spreads."
        ) from e
        
async def get_grouped_keyword_spreads(
    db: Session,
    payload: Optional[KeywordSpreadFilterSchema] = None,
) -> Any:
    """Retrieve non-deleted keyword spreads grouped by region code and keyword type.

    Args:
        db: SQLAlchemy session used to build the query.
        payload: Optional filters. ``payload.region_codes`` is a
            comma-separated string; blank entries are ignored and the
            remaining codes restrict results to those regions.

    Returns:
        A nested dict ``{region_code: {keyword_type: [row_dict, ...]}}``
        where each ``row_dict`` is ``KeywordSpreadOutputSchema`` dumped via
        ``model_dump()``. Because rows are fetched ordered by region code,
        keyword type, then ``date_created`` descending, dict insertion order
        follows that ordering and each list is newest-first.

    Raises:
        HTTPException / APIException: re-raised unchanged if raised while
            querying or serializing.
        APIException: 500 for any other failure, chained to the cause.
    """
    try:
        query = db.query(KeywordSpread).filter(KeywordSpread.deleted_at.is_(None))

        raw_codes = getattr(payload, "region_codes", None)
        if raw_codes:
            region_codes = [rc.strip() for rc in raw_codes.split(",") if rc.strip()]
            # A string of only commas/whitespace yields no codes: no filter.
            if region_codes:
                query = query.filter(KeywordSpread.region_code.in_(region_codes))

        keyword_spreads = query.order_by(
            KeywordSpread.region_code,
            KeywordSpread.keyword_type,
            KeywordSpread.date_created.desc(),
        ).all()

        grouped = defaultdict(lambda: defaultdict(list))
        for ks in keyword_spreads:
            grouped[ks.region_code][ks.keyword_type].append(
                KeywordSpreadOutputSchema.model_validate(ks).model_dump()
            )
        # Convert inner defaultdicts to plain dicts for a clean serializable result.
        return {region: dict(types) for region, types in grouped.items()}
    except (HTTPException, APIException):
        # Deliberate HTTP errors keep their own status code instead of
        # being flattened into a generic 500.
        raise
    except Exception as e:
        raise APIException(
            module="get_grouped_keyword_spreads",
            error={"exception": str(e)},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            message="Error retrieving grouped keyword spreads."
        ) from e