"""
Unit Sprite Module
Handles unit-specific sprite functionality
"""
import os
import pygame
from game.core.sprite import Sprite

class UnitSprite(Sprite):
    """
    Sprite class specifically for game units
    """
    def __init__(self, unit, position=(0, 0), scale=1.0, target_size=None):
        """
        Create the sprite for one game unit.

        Args:
            unit: Unit object or dictionary; ``id``, ``faction``,
                ``unit_type`` and ``position`` are read from it when present
            position: (x, y) top-left position on screen. When it equals
                (0, 0) the unit's own ``position`` is used instead
            scale: Scale factor for the image (used if target_size is None)
            target_size: Optional (width, height) tuple for fixed size
        """
        unit_is_dict = isinstance(unit, dict)

        # Position reported by the unit itself, (0, 0) when it has none
        if unit_is_dict and 'position' in unit:
            own_position = unit['position']
        elif hasattr(unit, 'position'):
            own_position = unit.position
        else:
            own_position = (0, 0)

        # A caller-supplied position other than the origin takes precedence
        if position != (0, 0):
            start_position = position
        else:
            start_position = own_position

        # One accessor for both unit flavours (dictionary or object)
        if unit_is_dict:
            read_field = unit.get
        else:
            def read_field(name, fallback):
                return getattr(unit, name, fallback)

        self.unit_id = read_field("id", "unknown")
        self.faction = read_field("faction", "default")
        self.unit_type = read_field("unit_type", "default")

        # Semi-transparent grey square handed to the parent constructor so
        # that it always receives a valid surface
        placeholder = pygame.Surface((64, 64), pygame.SRCALPHA)
        placeholder.fill((128, 128, 128, 128))
        self.temp_image = placeholder
        super().__init__(placeholder, start_position, scale, target_size)

        # Keep the unit, then try to swap the temporary image for real art
        self.unit = unit
        self._load_unit_sprite()

        # Active status effects (filled by add_status_effect)
        self.status_effects = []

        # Health bar palette: background plus one colour per health band
        low_health = (200, 0, 0)       # < 25% health - red
        mid_health = (200, 200, 0)     # < 50% health - yellow
        high_health = (0, 200, 0)      # >= 50% health - green
        self.hp_bar_bg_color = (40, 40, 40)
        self.hp_bar_colors = [low_health, mid_health, high_health]

        # Selection outline (yellow)
        self.is_selected = False
        self.selection_color = (255, 255, 0)
        self.outline_width = 2

        # Hover outline (light blue)
        self.is_hovered = False
        self.hover_color = (150, 150, 255)

        # Movement animation state; duration is in seconds
        self.move_start, self.move_end = None, None
        self.move_duration = 0.3
        self.move_timer = 0
        self.moving = False

        # Attack flash state; duration is in seconds, colour is red with alpha
        self.attack_flash_duration = 0.2
        self.attack_flash_timer = 0
        self.attack_flashing = False
        self.flash_color = (255, 0, 0, 150)

        # Floating damage text state; duration is in seconds
        self.damage_text = None
        self.damage_text_color = (255, 50, 50)
        self.damage_text_duration = 1.0
        self.damage_text_timer = 0
        self.damage_text_position = (0, 0)
        self.damage_text_showing = False
        
    def _is_multi_tile_unit(self):
        """Check if this is a multi-tile unit"""
        if isinstance(self.unit, dict):
            size = self.unit.get("size", (1, 1))
        else:
            size = getattr(self.unit, "size", (1, 1))
        return size[0] > 1 or size[1] > 1
        
    def _load_unit_sprite(self):
        """
        Load the sprite image for the unit, or install a placeholder.

        The image file is located by ``_find_unit_sprite_file`` and loaded
        through the parent's ``load_image``. When no file exists, or loading
        raises, a faction-coloured placeholder from the parent's
        ``_create_placeholder`` is installed instead and ``rect`` is
        re-anchored at ``self.position``.
        """
        filename = self._find_unit_sprite_file()

        # Track success explicitly. The constructor hands a temporary image
        # to the parent before this method runs, so (presumably) a set
        # ``self.image`` does not by itself mean real art was loaded.
        loaded = False
        if filename:
            try:
                # The parent's loader is expected to handle resizing
                super().load_image(filename)
            except Exception as e:
                # Any loader failure falls back to the placeholder below
                print(f"Error loading image {filename}: {e}")
            else:
                loaded = True
                print(f"Successfully loaded unit image from: {filename}")

        if loaded and getattr(self, 'image', None):
            return

        color = self._faction_placeholder_color()

        # Multi-tile units get a brighter shade so they stand out
        if self._is_multi_tile_unit():
            if isinstance(self.unit, dict):
                size = self.unit.get("size", (1, 1))
                unit_name = self.unit.get("name", "Unknown")
            else:
                size = getattr(self.unit, "size", (1, 1))
                unit_name = getattr(self.unit, "name", "Unknown")
            print(f"Creating distinctive placeholder for multi-tile unit: {unit_name} with size {size}")
            color = tuple(min(channel + 50, 255) for channel in color)

        self.image = super()._create_placeholder(color=color)
        self.original_image = self.image.copy()
        self.rect = self.image.get_rect()
        self.rect.topleft = self.position
        print(f"Using placeholder for unit: {self.faction} - {self.unit_type}")

    def _find_unit_sprite_file(self):
        """
        Return the path of the first existing sprite file, or None.

        Factions with irregular file names (space_marines, orks, tau,
        chaos_marines) are searched through a fixed candidate list; every
        other faction tries the unit's template id, then its unit type, then
        ``default``, each with .png/.jpg/.jpeg.
        """
        faction_dir = os.path.join("game", "assets", "images", "units", self.faction)
        unit_type = self.unit_type

        # Extra file names per faction, tried in order after
        # "<unit_type>.png" and "<unit_type>.jpg".
        # NOTE(review): the fixed names (e.g. "battle wagon.jpg") are tried
        # for every unit type of the faction, not only the type named in the
        # original comments - confirm this fallback is intended.
        special_names = {
            "space_marines": [
                f"spacemarine {unit_type}.jpg",
                f"space_marine_{unit_type}.jpg",
                f"space marine {unit_type}.jpg",
            ],
            "orks": [
                f"ork_{unit_type}.jpg",
                f"ork {unit_type}.jpg",
                "battle wagon.jpg",
                "battle_wagon.jpg",
            ],
            "tau": [
                f"tau_{unit_type}.jpg",
                f"tau {unit_type}.jpg",
                "crisis suit.jpg",
                "crisis_suit.jpg",
                "battle_suit.jpg",
            ],
            "chaos_marines": [
                f"chaos_{unit_type}.jpg",
                f"chaos {unit_type}.jpg",
                "chaos spacemarine.jpg",
                "chaos marine.jpg",
                "Khorne Berserkers.jpg",
                "khorne berserker.jpg",
                "berzerker.jpg",
                "lord of skulls.jpg",
                "lord_of_skulls.jpg",
            ],
        }

        if self.faction in special_names:
            candidates = [f"{unit_type}.png", f"{unit_type}.jpg"]
            candidates.extend(special_names[self.faction])
            for name in candidates:
                path = os.path.join(faction_dir, name)
                if os.path.exists(path):
                    return path
            return None

        def first_existing(stem):
            for ext in (".png", ".jpg", ".jpeg"):
                candidate = f"{stem}{ext}"
                if os.path.exists(candidate):
                    return candidate
            return None

        if isinstance(self.unit, dict):
            template_id = self.unit.get("template_id")
        else:
            template_id = getattr(self.unit, "template_id", None)

        # Skip an empty/None template id instead of passing it to
        # os.path.join, which would raise for None
        stems = [template_id] if template_id else []
        stems.extend([unit_type, "default"])

        for stem in stems:
            found = first_existing(os.path.join(faction_dir, stem))
            if found:
                return found
        return None

    def _faction_placeholder_color(self):
        """Return the RGB placeholder colour for the unit's faction (grey if unknown)."""
        faction_colors = {
            "space_marines": (0, 0, 200),    # Blue
            "chaos_marines": (150, 0, 0),    # Dark red
            "eldar": (0, 150, 150),          # Teal
            "orks": (0, 100, 0),             # Dark green
            "tau": (200, 100, 0),            # Orange
            "necrons": (50, 200, 50),        # Green
        }
        return faction_colors.get(self.faction, (128, 128, 128))
    
    def set_selected(self, selected):
        """Set whether the unit is selected"""
        self.is_selected = selected
    
    def set_hovered(self, hovered):
        """Set whether the unit is being hovered over"""
        self.is_hovered = hovered
    
    def start_move_animation(self, start_pos, end_pos):
        """
        Start an animation to move the unit from start to end position
        
        Args:
            start_pos: Starting grid position (x, y)
            end_pos: Ending grid position (x, y)
        """
        self.move_start = start_pos
        self.move_end = end_pos
        self.move_timer = 0
        self.moving = True
    
    def start_attack_animation(self):
        """Start the attack flash animation"""
        self.attack_flash_timer = 0
        self.attack_flashing = True
    
    def show_damage_text(self, damage, position):
        """
        Show damage text animation
        
        Args:
            damage: Amount of damage to display
            position: Position to show the text
        """
        self.damage_text = str(damage)
        self.damage_text_position = position
        self.damage_text_timer = 0
        self.damage_text_showing = True
    
    def add_status_effect(self, effect_name, duration):
        """
        Add a status effect to the unit
        
        Args:
            effect_name: Name of the effect
            duration: Duration in seconds
        """
        self.status_effects.append({
            "name": effect_name,
            "duration": duration,
            "timer": 0
        })
    
    def update(self, dt):
        """
        Advance every running animation and status effect by ``dt``.

        Args:
            dt: Time delta in seconds
        """
        super().update(dt)

        # Movement: slide linearly from move_start to move_end, then snap
        if self.moving:
            self.move_timer += dt
            elapsed, total = self.move_timer, self.move_duration
            if elapsed >= total:
                self.moving = False
                self.set_position(self.move_end[0], self.move_end[1])
            else:
                progress = elapsed / total
                origin, target = self.move_start, self.move_end
                self.set_position(
                    origin[0] + (target[0] - origin[0]) * progress,
                    origin[1] + (target[1] - origin[1]) * progress,
                )

        # Attack flash: switches itself off once its duration has elapsed
        if self.attack_flashing:
            self.attack_flash_timer += dt
            if self.attack_flash_timer >= self.attack_flash_duration:
                self.attack_flashing = False

        # Damage text: same pattern as the attack flash
        if self.damage_text_showing:
            self.damage_text_timer += dt
            if self.damage_text_timer >= self.damage_text_duration:
                self.damage_text_showing = False

        # Status effects: walk a snapshot so expired entries can be removed
        for entry in list(self.status_effects):
            entry["timer"] += dt
            if entry["timer"] >= entry["duration"]:
                self.status_effects.remove(entry)
    
    def _draw_hp_bar(self, screen):
        """
        Draw the unit's health bar just below the sprite.

        Health is read from ``hp``/``max_hp`` for dictionary units, or from
        ``current_hp`` and ``stats["hp"]`` for unit objects. When the data is
        missing, or the maximum is not a positive number, a full bar is
        drawn rather than raising.

        Args:
            screen: Pygame surface to draw on
        """
        current_hp = None
        max_hp = None
        if isinstance(self.unit, dict) and "hp" in self.unit and "max_hp" in self.unit:
            current_hp = self.unit["hp"]
            max_hp = self.unit["max_hp"]
        elif hasattr(self.unit, "current_hp") and hasattr(self.unit, "stats"):
            current_hp = self.unit.current_hp
            max_hp = self.unit.stats["hp"]

        # Guard the division: a zero/negative/None maximum would otherwise
        # raise inside the render loop
        hp_percent = 1.0
        if current_hp is not None and max_hp is not None and max_hp > 0:
            hp_percent = max(0.0, min(1.0, current_hp / max_hp))

        # HP bar background: full sprite width, 2px under the sprite
        bar_width = self.rect.width
        bar_height = 4
        bar_rect = pygame.Rect(
            self.rect.left,
            self.rect.bottom + 2,
            bar_width,
            bar_height
        )
        pygame.draw.rect(screen, self.hp_bar_bg_color, bar_rect)

        # HP bar fill, coloured by health band
        if hp_percent > 0:
            if hp_percent < 0.25:
                fill_color = self.hp_bar_colors[0]  # Red
            elif hp_percent < 0.5:
                fill_color = self.hp_bar_colors[1]  # Yellow
            else:
                fill_color = self.hp_bar_colors[2]  # Green

            fill_width = int(bar_width * hp_percent)
            fill_rect = pygame.Rect(bar_rect.left, bar_rect.top, fill_width, bar_height)
            pygame.draw.rect(screen, fill_color, fill_rect)
    
    def _draw_damage_text(self, screen):
        """
        Draw the floating damage number while its animation is running.

        The text rises up to 20px above the sprite over the animation and
        fades out during the last 30% of the duration.

        Args:
            screen: Pygame surface to draw on
        """
        if not self.damage_text_showing:
            return

        # The font is created on first use and cached on the instance
        if not hasattr(self, 'damage_font'):
            self.damage_font = pygame.font.Font(None, 20)

        rendered = self.damage_font.render(self.damage_text, True, self.damage_text_color)
        placement = rendered.get_rect()

        elapsed = self.damage_text_timer
        duration = self.damage_text_duration

        # Upward drift grows with elapsed time
        rise = int(-20 * (elapsed / duration))
        placement.midbottom = (self.rect.centerx, self.rect.top + rise)

        # Fade from opaque to transparent over the final part of the animation
        fade_start = duration * 0.7
        if elapsed > fade_start:
            alpha = int(255 * (1 - (elapsed - fade_start) / (duration * 0.3)))
            rendered.set_alpha(alpha)

        screen.blit(rendered, placement)
    
    def _draw_status_effects(self, screen):
        """
        Draw one small coloured dot per active status effect.

        Dots are stacked downward, 12px apart, starting at the sprite's top
        edge and 5px to the right of it.

        Args:
            screen: Pygame surface to draw on
        """
        known_colors = (
            ("poisoned", (0, 150, 0)),    # Green
            ("stunned", (150, 150, 0)),   # Yellow
            ("buffed", (0, 0, 150)),      # Blue
        )
        fallback_color = (100, 100, 100)  # Gray

        dot_x = self.rect.right + 5
        dot_y = self.rect.top

        for effect in self.status_effects:
            dot_color = fallback_color
            for known_name, known_color in known_colors:
                if effect["name"] == known_name:
                    dot_color = known_color
                    break

            pygame.draw.circle(screen, dot_color, (dot_x, dot_y), 5)
            dot_y += 12  # Space between status indicators
    
    def draw(self, screen):
        """
        Draw the unit sprite and its overlays on the screen.

        Order: multi-tile backdrop, sprite image, selection or hover
        outline, HP bar, damage text, status effects, then the attack flash
        on top. Nothing is drawn when the sprite has no image or rect.

        Args:
            screen: Pygame surface to draw on
        """
        # Skip drawing if we don't have an image loaded
        if not getattr(self, 'image', None) or not hasattr(self, 'rect'):
            return

        # For multi-tile units, draw a subtle backdrop slightly larger than
        # the unit; inflate() returns a new rect so self.rect is untouched
        if self._is_multi_tile_unit():
            backdrop = self.rect.inflate(4, 4)
            pygame.draw.rect(screen, (70, 70, 70), backdrop, border_radius=2)

        # Draw the sprite image (either placeholder or actual unit image)
        screen.blit(self.image, self.rect)

        # Selection outline takes priority over the hover outline
        if self.is_selected:
            pygame.draw.rect(screen, self.selection_color, self.rect, self.outline_width, border_radius=2)
        elif self.is_hovered:
            pygame.draw.rect(screen, self.hover_color, self.rect, self.outline_width, border_radius=2)

        self._draw_hp_bar(screen)
        self._draw_damage_text(screen)
        self._draw_status_effects(screen)

        # Attack flash overlay, drawn once and last so it tints the sprite.
        # (A second copy used to be blitted underneath the sprite, where the
        # image covered it.)
        if self.attack_flashing:
            flash_surface = pygame.Surface(self.rect.size, pygame.SRCALPHA)
            flash_surface.fill(self.flash_color)
            screen.blit(flash_surface, self.rect.topleft)

    def set_position(self, x, y):
        """
        Set the sprite position
        
        Args:
            x: X coordinate (screen coordinates)
            y: Y coordinate (screen coordinates)
        """
        self.position = (x, y)
        if hasattr(self, 'rect'):
            self.rect.topleft = (x, y) 