from datetime import datetime, timezone
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from src.core.exceptions import APIException
from src.apps.base.models.regions import Region, SubRegion
from src.core.config import settings
from src.utils.constants import API_PREFIXES
from src.apps.base.schemas.region import (
    RegionCreateSchema,
    RegionUpdateSchema,
    RegionOutSchema,
    SubRegionCreateSchema,
    SubRegionUpdateSchema,
    SubRegionOutSchema,
    RegionFilterSchema,
    SubRegionFilterSchema
)
from src.apps.base.services.country import get_country_by_id
from src.utils.pagination import QueryPaginator
from src.apps.base.models.country import Country
from uuid import UUID
from sqlalchemy.exc import IntegrityError
from typing import Optional


# ------------------------ Region ------------------------

async def get_region_by_id(db: Session, region_id: int) -> Region:
    """Fetch a single non-deleted region by primary key.

    Args:
        db: Active SQLAlchemy session.
        region_id: Primary key of the region to load.

    Returns:
        The matching ``Region`` whose ``deleted_at`` is ``None``.

    Raises:
        APIException: 404 when no live region has this id; 500 when the
            query itself fails.
    """
    try:
        region = db.query(Region).filter(
            Region.id == region_id,
            Region.deleted_at.is_(None),
        ).first()
    except Exception as e:
        raise APIException(
            module="Region",
            error={"exception": str(e)},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            message="Error fetching region by ID.",
        ) from e

    # Raised outside the try so the deliberate 404 is not caught and
    # re-wrapped by the generic error handler above.
    if not region:
        raise APIException(
            module="Region",
            error={"exception": "Region not found"},
            status_code=status.HTTP_404_NOT_FOUND,
            message="Region not found",
        )
    return region


async def get_all_regions(db: Session, page: int = 1, per_page: int = 10, payload: Optional[RegionFilterSchema] = None):
    """Return a paginated listing of non-deleted regions.

    Args:
        db: Active SQLAlchemy session.
        page: 1-based page number; converted to an offset of
            ``(page - 1) * per_page``.
        per_page: Page size, used as the query limit.
        payload: Optional filter. When ``country_id`` is truthy, only
            regions belonging to that country are returned.

    Returns:
        The result of ``QueryPaginator.paginate()`` over ``RegionOutSchema``.

    Raises:
        APIException: 404 when the filter names a country that does not
            exist; 500 for any other failure while querying or paginating.
    """
    try:
        offset = (page - 1) * per_page
        query = db.query(Region).filter(Region.deleted_at.is_(None))

        if payload and payload.country_id:
            country = db.query(Country).filter(Country.id == payload.country_id).first()
            if not country:
                raise APIException(
                    module="Region",
                    error={"exception": "Country not found"},
                    status_code=status.HTTP_404_NOT_FOUND,
                    message="Country not found",
                )
            # Filter on the FK directly; no join is needed to restrict by country.
            query = query.filter(Region.country_id == country.id)

        paginator = QueryPaginator(
            query=query,
            schema=RegionOutSchema,
            url="".join([str(settings.api_base_url()), API_PREFIXES.REGION]),
            offset=offset,
            limit=per_page,
            use_orm=True,
        )
        return paginator.paginate()
    except APIException:
        # Preserve deliberate API errors (e.g. the 404 above) instead of
        # re-wrapping them as a generic 500.
        raise
    except Exception as e:
        raise APIException(
            module="get_all_regions",
            error={"exception": str(e)},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            message="Failed to fetch regions.",
        ) from e

async def create_region(db: Session, payload: RegionCreateSchema) -> Region:
    """Create a region, or revive a soft-deleted one with the same code.

    The lookup on ``region_code`` deliberately ignores ``deleted_at``. A
    soft-deleted region with the same code is restored and overwritten with
    the payload instead of inserting a duplicate row.

    Args:
        db: Active SQLAlchemy session.
        payload: Data for the new region; ``country_id`` must reference an
            existing country.

    Returns:
        The newly created or restored ``Region``.

    Raises:
        APIException: 400 when a live region already uses ``region_code``,
            or when the lookup, restore or insert fails. Whatever
            ``get_country_by_id`` raises for an unknown country propagates
            unchanged.
    """
    # Validates that the referenced country exists; the returned row is not needed.
    await get_country_by_id(db, payload.country_id)

    try:
        existing = db.query(Region).filter(Region.region_code == payload.region_code).first()
    except Exception as e:
        raise APIException(
            module="Region",
            error={"exception": str(e)},
            status_code=status.HTTP_400_BAD_REQUEST,
            message="Error checking existing region.",
        ) from e

    if existing:
        if existing.deleted_at is None:
            raise APIException(
                module="Region",
                error={"exception": "Region with this code already exists"},
                status_code=status.HTTP_400_BAD_REQUEST,
                message="Region with this code already exists",
            )

        # Restore the soft-deleted region, overwriting its fields with the payload.
        try:
            existing.deleted_at = None
            # delete_region sets is_deleted = True, so clear it on restore too.
            existing.is_deleted = False
            for field, value in payload.dict().items():
                setattr(existing, field, value)
            db.commit()
            db.refresh(existing)
            return existing
        except Exception as e:
            # Undo the partially applied restore so the session stays usable.
            db.rollback()
            raise APIException(
                module="Region",
                error={"exception": str(e)},
                status_code=status.HTTP_400_BAD_REQUEST,
                message="Error restoring soft-deleted region.",
            ) from e

    region = Region(**payload.dict())
    db.add(region)
    try:
        db.commit()
        db.refresh(region)
        return region
    except Exception as e:
        db.rollback()
        raise APIException(
            module="Region",
            error={"exception": str(e)},
            status_code=status.HTTP_400_BAD_REQUEST,
            message="Error creating region.",
        ) from e


async def update_region(db: Session, region_id: int, payload: RegionUpdateSchema) -> Region:
    """Apply a partial update to a non-deleted region.

    Args:
        db: Active SQLAlchemy session.
        region_id: Primary key of the region to update.
        payload: Only the fields explicitly set on it
            (``dict(exclude_unset=True)``) are written.

    Returns:
        The refreshed ``Region``.

    Raises:
        APIException: from ``get_region_by_id`` when the region is not
            found; 400 when applying or committing the update fails.
            Whatever ``get_country_by_id`` raises for an unknown
            ``country_id`` propagates unchanged.
    """
    region = await get_region_by_id(db, region_id)
    update_data = payload.dict(exclude_unset=True)

    # Validate a new country reference before touching the row. This is kept
    # outside the try below so its error is not re-wrapped as a 400.
    if "country_id" in update_data:
        await get_country_by_id(db, update_data["country_id"])

    try:
        for field, value in update_data.items():
            setattr(region, field, value)
        db.commit()
        db.refresh(region)
        return region
    except Exception as e:
        db.rollback()
        raise APIException(
            module="Region",
            error={"exception": str(e)},
            status_code=status.HTTP_400_BAD_REQUEST,
            message="Error updating region.",
        ) from e




async def delete_region(db: Session, region_id: int) -> None:
    """Soft-delete a region by flagging it and stamping ``deleted_at``.

    Linked sub-regions are not touched by this function.

    Args:
        db: Active SQLAlchemy session.
        region_id: Primary key of the region to delete.

    Raises:
        APIException: from ``get_region_by_id`` when the region is not
            found; 400 when the commit fails.
    """
    region = await get_region_by_id(db, region_id)

    # TODO: decide whether deletion should be refused while sub-regions or
    # locations still reference this region. A guard on region.sub_regions /
    # region.locations was drafted but never enabled.

    try:
        region.is_deleted = True
        # Timezone-aware UTC, matching delete_sub_region. datetime.utcnow()
        # returns a naive value and is deprecated since Python 3.12.
        region.deleted_at = datetime.now(timezone.utc)
        db.commit()
        return None
    except Exception as e:
        db.rollback()
        raise APIException(
            module="Region",
            error={"exception": str(e)},
            status_code=status.HTTP_400_BAD_REQUEST,
            message="Error deleting region.",
        ) from e



# ------------------------ SubRegion ------------------------

async def get_sub_region_by_id(db: Session, sub_region_id: int) -> SubRegion:
    """Fetch a single non-deleted sub-region by primary key.

    Args:
        db: Active SQLAlchemy session.
        sub_region_id: Primary key of the sub-region to load.

    Returns:
        The matching ``SubRegion`` whose ``deleted_at`` is ``None``.

    Raises:
        APIException: 404 when no live sub-region has this id; 500 when
            the query itself fails.
    """
    try:
        sub_region = db.query(SubRegion).filter(
            SubRegion.id == sub_region_id,
            SubRegion.deleted_at.is_(None),
        ).first()
    except Exception as e:
        raise APIException(
            module="SubRegion",
            error={"exception": str(e)},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            message="Error fetching sub-region by ID.",
        ) from e

    # Raised outside the try so the deliberate 404 is not caught and
    # re-wrapped by the generic error handler above.
    if not sub_region:
        raise APIException(
            module="SubRegion",
            error={"exception": "Sub-region not found"},
            status_code=status.HTTP_404_NOT_FOUND,
            message="Sub-region not found",
        )
    return sub_region


async def get_all_sub_regions(
    db: Session,
    page: int = 1,
    per_page: int = 10,
    payload: Optional[SubRegionFilterSchema] = None,
):
    """Return a paginated listing of non-deleted sub-regions.

    Args:
        db: Active SQLAlchemy session.
        page: 1-based page number; converted to an offset of
            ``(page - 1) * per_page``.
        per_page: Page size, used as the query limit.
        payload: Optional filter. When ``region_id`` is truthy, only
            sub-regions of that region are returned.

    Returns:
        The result of ``QueryPaginator.paginate()`` over ``SubRegionOutSchema``.

    Raises:
        APIException: 404 when the filter names a region that does not
            exist; 500 for any other failure while querying or paginating.
    """
    try:
        offset = (page - 1) * per_page
        query = db.query(SubRegion).filter(SubRegion.deleted_at.is_(None))

        if payload and payload.region_id:
            region = db.query(Region).filter(Region.id == payload.region_id).first()
            if not region:
                raise APIException(
                    module="SubRegion",
                    error={"exception": "Region not found"},
                    status_code=status.HTTP_404_NOT_FOUND,
                    message="Region not found",
                )
            query = query.filter(SubRegion.region_id == region.id)

        paginator = QueryPaginator(
            query=query,
            schema=SubRegionOutSchema,
            url="".join([str(settings.api_base_url()), API_PREFIXES.SUB_REGION]),
            offset=offset,
            limit=per_page,
            use_orm=True,
        )
        return paginator.paginate()
    except APIException:
        # Preserve deliberate API errors (e.g. the 404 above) instead of
        # re-wrapping them as a generic 500.
        raise
    except Exception as e:
        raise APIException(
            module="get_all_sub_regions",
            error={"exception": str(e)},
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            message="Failed to fetch sub-regions.",
        ) from e


async def create_sub_region(db: Session, payload: SubRegionCreateSchema) -> SubRegion:
    """Create a sub-region under an existing, non-deleted region.

    Only ``name`` and ``region_id`` are copied from the payload.

    Args:
        db: Active SQLAlchemy session.
        payload: Data for the new sub-region.

    Returns:
        The newly created and refreshed ``SubRegion``.

    Raises:
        APIException: from ``get_region_by_id`` when the parent region is
            not found; 400 when the insert fails.
    """
    # Resolve the parent outside the try so a "region not found" error is
    # not re-wrapped as a generic 400 "Error creating sub-region.".
    region = await get_region_by_id(db, payload.region_id)

    sub_region = SubRegion(
        name=payload.name,
        region_id=region.id,
    )

    try:
        db.add(sub_region)
        db.commit()
        db.refresh(sub_region)
        return sub_region
    except Exception as e:
        db.rollback()
        raise APIException(
            module="SubRegion",
            error={"exception": str(e)},
            status_code=status.HTTP_400_BAD_REQUEST,
            message="Error creating sub-region.",
        ) from e


async def update_sub_region(db: Session, sub_region_id: int, payload: SubRegionUpdateSchema) -> SubRegion:
    """Apply a partial update to a non-deleted sub-region.

    Args:
        db: Active SQLAlchemy session.
        sub_region_id: Primary key of the sub-region to update.
        payload: Only the fields explicitly set on it
            (``dict(exclude_unset=True)``) are written.

    Returns:
        The refreshed ``SubRegion``.

    Raises:
        APIException: from ``get_sub_region_by_id`` / ``get_region_by_id``
            when the sub-region or the new parent region is not found;
            400 when applying or committing the update fails.
    """
    sub_region = await get_sub_region_by_id(db, sub_region_id)
    update_data = payload.dict(exclude_unset=True)

    # Validate a re-parenting target before touching the row. This is kept
    # outside the try below so a "region not found" error is not re-wrapped
    # as a 400.
    if "region_id" in update_data:
        await get_region_by_id(db, update_data["region_id"])

    try:
        for field, value in update_data.items():
            setattr(sub_region, field, value)
        db.commit()
        db.refresh(sub_region)
        return sub_region
    except Exception as e:
        db.rollback()
        raise APIException(
            module="SubRegion",
            error={"exception": str(e)},
            status_code=status.HTTP_400_BAD_REQUEST,
            message="Error updating sub-region.",
        ) from e


async def delete_sub_region(db: Session, sub_region_id: int) -> None:
    """Soft-delete a sub-region by flagging it and stamping ``deleted_at``.

    Args:
        db: Active SQLAlchemy session.
        sub_region_id: Primary key of the sub-region to delete.

    Raises:
        APIException: from ``get_sub_region_by_id`` when the sub-region is
            not found; 400 when the commit fails (the session is rolled
            back first).
    """
    target = await get_sub_region_by_id(db, sub_region_id)

    try:
        target.is_deleted = True
        target.deleted_at = datetime.now(timezone.utc)  # timezone-aware UTC
        db.commit()
    except Exception as exc:
        db.rollback()
        raise APIException(
            module="SubRegion",
            error={"exception": str(exc)},
            status_code=status.HTTP_400_BAD_REQUEST,
            message="Error deleting sub-region.",
        )


