"""
Unit Renderer Module
Handles rendering of units in the battle scene
"""
import pygame
from game.core.unit_sprite import UnitSprite

class UnitRenderer:
    """
    Handles rendering of units in battle.

    Keeps one UnitSprite per unit ID, positions those sprites in screen pixels
    from grid coordinates, forwards selection/hover/animation state to them,
    and draws them either as sprites or, when sprites are unavailable, as
    faction-coloured circles with an HP bar.
    """
    def __init__(self, resource_manager, grid_size, cell_size):
        """
        Initialize the unit renderer

        Args:
            resource_manager: ResourceManager instance; its optional
                ``has_sprites`` attribute selects sprite vs. shape rendering
            grid_size: Grid dimensions (width, height) in cells
            cell_size: Size of each grid cell in pixels
        """
        self.resource_manager = resource_manager
        self.grid_size = grid_size
        self.cell_size = cell_size

        # Dictionary of unit sprites by unit ID
        self.unit_sprites = {}

        # Pixel offset of the grid, remembered from the last
        # update_unit_positions() call so that sprite creation and move
        # animations use the same screen coordinates as position updates.
        self.grid_offset = (0, 0)

        # Whether to use sprites or shapes; managers without the attribute
        # fall back to shapes.
        self.use_sprites = getattr(resource_manager, "has_sprites", False)

        # Sprite scale factor
        self.sprite_scale = 0.8  # 80% of cell size

        # Default unit colors by faction (fallback when sprites not available)
        self.faction_colors = {
            "space_marines": (0, 0, 200),      # Blue
            "chaos_marines": (150, 0, 0),      # Dark red
            "eldar": (0, 150, 150),           # Teal
            "orks": (0, 100, 0),              # Dark green
            "tau": (200, 100, 0),             # Orange
            "necrons": (50, 200, 50),         # Green
            "default": (128, 128, 128)        # Gray
        }

    def _get_unit_id(self, unit):
        """Get the ID of a unit (handles both dictionaries and objects).

        Returns "unknown" when the unit has no ``id`` key/attribute.
        """
        if isinstance(unit, dict):
            return unit.get("id", "unknown")
        return getattr(unit, "id", "unknown")

    def _grid_to_screen(self, grid_pos, offset=None):
        """Return the pixel centre of cell ``grid_pos``.

        Args:
            grid_pos: (x, y) cell coordinates
            offset: (x, y) pixel offset; defaults to the offset remembered
                from the last update_unit_positions() call
        """
        ox, oy = self.grid_offset if offset is None else offset
        half_cell = self.cell_size // 2
        return (
            ox + grid_pos[0] * self.cell_size + half_cell,
            oy + grid_pos[1] * self.cell_size + half_cell,
        )

    def _find_sprite(self, unit):
        """Return the sprite registered for ``unit``, or None.

        Falsy units (e.g. None) never match.
        """
        if not unit:
            return None
        return self.unit_sprites.get(self._get_unit_id(unit))

    def create_unit_sprite(self, unit, grid_pos):
        """
        Create a sprite for a unit and register it under the unit's ID

        Args:
            unit: Unit object or dict
            grid_pos: (x, y) position on the grid

        Returns:
            The newly created UnitSprite (replacing any sprite with the same ID)
        """
        sprite = UnitSprite(unit, self._grid_to_screen(grid_pos), self.sprite_scale)
        self.unit_sprites[self._get_unit_id(unit)] = sprite
        return sprite

    def update_unit_positions(self, grid, offset=(0, 0)):
        """
        Sync sprites with the units currently on the grid

        Sprites for units no longer on the grid are removed, sprites for newly
        seen units are created, and every sprite is moved to the centre of its
        cell shifted by ``offset``.

        Args:
            grid: 2D grid of units indexed as grid[x][y]; empty cells are falsy.
                Empty grids and columns of differing length are accepted.
            offset: (x, y) grid offset in pixels
        """
        self.grid_offset = (offset[0], offset[1])

        # Collect occupied cells once, in grid order. enumerate() copes with an
        # empty grid and ragged columns, unlike indexing with len(grid[0]).
        occupied = [
            (x, y, unit)
            for x, column in enumerate(grid)
            for y, unit in enumerate(column)
            if unit
        ]
        live_ids = {self._get_unit_id(unit) for _, _, unit in occupied}

        # Remove sprites for units no longer in the grid (iterate a copy of
        # the keys because the dict is mutated).
        for unit_id in list(self.unit_sprites):
            if unit_id not in live_ids:
                del self.unit_sprites[unit_id]

        # Move existing sprites and create sprites for new units
        for x, y, unit in occupied:
            sprite = self.unit_sprites.get(self._get_unit_id(unit))
            if sprite is None:
                sprite = self.create_unit_sprite(unit, (x, y))
            sprite.set_position(*self._grid_to_screen((x, y)))

    def set_selected_unit(self, unit):
        """
        Set the selected unit, clearing selection on every other sprite

        Args:
            unit: Selected unit or None
        """
        for sprite in self.unit_sprites.values():
            sprite.set_selected(False)

        sprite = self._find_sprite(unit)
        if sprite is not None:
            sprite.set_selected(True)

    def set_hovered_unit(self, unit):
        """
        Set the hovered unit, clearing hover on every other sprite

        Args:
            unit: Hovered unit or None
        """
        for sprite in self.unit_sprites.values():
            sprite.set_hovered(False)

        sprite = self._find_sprite(unit)
        if sprite is not None:
            sprite.set_hovered(True)

    def animate_unit_move(self, unit, start_pos, end_pos):
        """
        Animate a unit moving from start to end position

        Grid positions are converted with the same pixel offset used by the
        last update_unit_positions() call, so the animation lines up with the
        positions that method assigns.

        Args:
            unit: Unit to animate
            start_pos: Starting grid position (x, y)
            end_pos: Ending grid position (x, y)
        """
        sprite = self._find_sprite(unit)
        if sprite is not None:
            sprite.start_move_animation(
                self._grid_to_screen(start_pos),
                self._grid_to_screen(end_pos),
            )

    def animate_unit_attack(self, attacker, defender, damage):
        """
        Animate an attack between units

        Args:
            attacker: Attacking unit (its sprite plays the attack animation)
            defender: Defending unit (its sprite shows the damage text)
            damage: Amount of damage dealt
        """
        attacker_sprite = self._find_sprite(attacker)
        if attacker_sprite is not None:
            attacker_sprite.start_attack_animation()

        defender_sprite = self._find_sprite(defender)
        if defender_sprite is not None:
            defender_sprite.show_damage_text(damage, defender_sprite.rect.midtop)

    def update(self, dt):
        """
        Update all unit sprites

        Args:
            dt: Time delta in seconds
        """
        for sprite in self.unit_sprites.values():
            sprite.update(dt)

    @staticmethod
    def _get_hp_percent(unit):
        """Return the unit's HP as a fraction clamped to [0.0, 1.0].

        Dict units use ``hp``/``max_hp``; object units use
        ``current_hp``/``stats["hp"]``. When HP data is missing, or the maximum
        is zero/negative, 1.0 is returned (drawn as a full bar) instead of
        dividing by zero.
        """
        if isinstance(unit, dict) and "hp" in unit and "max_hp" in unit:
            current, maximum = unit["hp"], unit["max_hp"]
        elif hasattr(unit, "current_hp") and hasattr(unit, "stats"):
            current, maximum = unit.current_hp, unit.stats["hp"]
        else:
            return 1.0

        if not maximum or maximum < 0:
            return 1.0
        return min(1.0, max(0.0, current / maximum))

    @staticmethod
    def _hp_fill_color(hp_percent):
        """Return the HP bar fill colour: green > 60%, yellow > 30%, else red."""
        if hp_percent > 0.6:
            return (0, 200, 0)  # Green
        if hp_percent > 0.3:
            return (200, 200, 0)  # Yellow
        return (200, 0, 0)  # Red

    def _draw_hp_bar(self, screen, pos, circle_radius, bar_width, hp_percent):
        """Draw the HP bar centred under the unit circle at ``pos``."""
        bar_height = 4
        bar_rect = pygame.Rect(
            pos[0] - bar_width // 2,
            pos[1] + circle_radius + 4,
            bar_width,
            bar_height
        )
        pygame.draw.rect(screen, (40, 40, 40), bar_rect)  # background

        if hp_percent > 0:
            fill_width = int(bar_width * hp_percent)
            fill_rect = pygame.Rect(bar_rect.left, bar_rect.top, fill_width, bar_height)
            pygame.draw.rect(screen, self._hp_fill_color(hp_percent), fill_rect)

    def _draw_unit_shape(self, screen, sprite, circle_radius, bar_width):
        """Draw one unit as a faction-coloured circle, HP bar and highlight."""
        pos = sprite.position
        color = self.faction_colors.get(sprite.faction, self.faction_colors["default"])

        pygame.draw.circle(screen, color, pos, circle_radius)
        pygame.draw.circle(screen, (200, 200, 200), pos, circle_radius, 2)

        self._draw_hp_bar(
            screen, pos, circle_radius, bar_width, self._get_hp_percent(sprite.unit)
        )

        # Selection takes precedence over hover.
        if sprite.is_selected:
            highlight = (255, 255, 0)  # Yellow
        elif sprite.is_hovered:
            highlight = (150, 150, 255)  # Light blue
        else:
            highlight = None
        if highlight is not None:
            pygame.draw.circle(screen, highlight, pos, circle_radius + 4, 2)

    def render(self, screen, grid_offset=(0, 0)):
        """
        Render all unit sprites

        Args:
            screen: Surface to render on
            grid_offset: (x, y) offset for the grid in pixels. Currently
                unused: sprite positions already include the offset passed to
                update_unit_positions(). Kept for interface compatibility.
        """
        if self.use_sprites:
            for sprite in self.unit_sprites.values():
                sprite.draw(screen)
            return

        # Fallback: simple shapes. Sizes depend only on cell_size, so they are
        # computed once per frame; the bar width is clamped for tiny cells.
        circle_radius = int(self.cell_size * 0.4)
        bar_width = max(0, self.cell_size - 10)
        for sprite in self.unit_sprites.values():
            self._draw_unit_shape(screen, sprite, circle_radius, bar_width)