
    kh>                     j    d Z ddlZddlZddZddZddZddZddZ G d d	      Z	 G d
 de	      Z
y)aY   Mixup and Cutmix

Papers:
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)

CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)

Code Reference:
CutMix: https://github.com/clovaai/CutMix-PyTorch

Hacked together by / Copyright 2019, Ross Wightman
    Nc                     | j                         j                  dd      } t        j                  | j	                         d   |f|| j
                        j                  d| |      S )N   r   )device)longviewtorchfullsizer   scatter_)xnum_classeson_value	off_values       K/var/www/teggl/fontify/venv/lib/python3.12/site-packages/timm/data/mixup.pyone_hotr      sP    	b!A::qvvx{K0)AHHMVVWXZ[]eff    c                     ||z  }d|z
  |z   }t        | |||      }t        | j                  d      |||      }||z  |d|z
  z  z   S )N      ?)r   r   r   )r   flip)targetr   lam	smoothingr   r   y1y2s           r   mixup_targetr      sY    K'II~	)H	x9	MB	Qx9	UB8bBHo%%r   c                 F   t        j                  d|z
        }| dd \  }}t        ||z        t        ||z        }}t        ||z        t        ||z        }
}	t         j                  j	                  d|	z   ||	z
  |      }t         j                  j	                  d|
z   ||
z
  |      }t        j
                  ||dz  z
  d|      }t        j
                  ||dz  z   d|      }t        j
                  ||dz  z
  d|      }t        j
                  ||dz  z   d|      }||||fS )a   Standard CutMix bounding-box
    Generates a random square bbox based on lambda value. This impl includes
    support for enforcing a border margin as percent of bbox dimensions.

    Args:
        img_shape (tuple): Image shape as tuple
        lam (float): Cutmix lambda value
        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
        count (int): Number of bbox to generate
    r   Nr   r      )npsqrtintrandomrandintclip)	img_shaper   margincountratioimg_himg_wcut_hcut_wmargin_ymargin_xcycxylyhxlxhs                    r   	rand_bboxr7      s    GGAGERS>LE5uu}%s55='95EVe^,c&5..AhH			1x<)9		FB			1x<)9		FB	eqj!U	+B	eqj!U	+B	eqj!U	+B	eqj!U	+Br2r>r   c                    t        |      dk(  sJ | dd \  }}t        j                  j                  t	        ||d   z        t	        ||d   z        |      }t        j                  j                  t	        ||d   z        t	        ||d   z        |      }t        j                  j                  d||z
  |      }t        j                  j                  d||z
  |      }||z   }	||z   }
||	||
fS )a   Min-Max CutMix bounding-box
    Inspired by Darknet cutmix impl, generates a random rectangular bbox
    based on min/max percent values applied to each dimension of the input image.

    Typical defaults for minmax are usually in the  .2-.3 for min and .8-.9 range for max.

    Args:
        img_shape (tuple): Image shape as tuple
        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
        count (int): Number of bbox to generate
    r    r   Nr   r   r   )lenr!   r$   r%   r#   )r'   minmaxr)   r+   r,   r-   r.   r3   r5   yuxus              r   rand_bbox_minmaxr=   6   s     v;!RS>LE5IIc%&)"34c%&):K6LSXYEIIc%&)"34c%&):K6LSXYE			1eem%		8B			1eem%		8B	eB	eBr2r>r   c                     |t        | ||      \  }}}}nt        | ||      \  }}}}|s|%||z
  ||z
  z  }	d|	t        | d   | d   z        z  z
  }||||f|fS )z0 Generate bbox and apply lambda correction.
    )r)   r   r   r   )r=   r7   float)
r'   r   ratio_minmaxcorrect_lamr)   r3   r;   r5   r<   	bbox_areas
             r   cutmix_bbox_and_lamrC   M   s     ))\OBB"9c?BBl."Wb)	9uYr]Yr]%BCCCBS  r   c                   @    e Zd ZdZ	 	 d
dZd Zd Zd Zd Zd Z	d	 Z
y)Mixupas   Mixup/Cutmix that applies different params to each element or whole batch

    Args:
        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
        prob (float): probability of applying mixup or cutmix per batch or element
        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
        label_smoothing (float): apply label smoothing to the mixed target tensor
        num_classes (int): number of classes for target
    Nc
                     || _         || _        || _        | j                  !t        | j                        dk(  sJ d| _        || _        || _        || _        |	| _        || _        || _	        d| _
        y )Nr    r   T)mixup_alphacutmix_alphacutmix_minmaxr9   mix_probswitch_problabel_smoothingr   moderA   mixup_enabled)
selfrG   rH   rI   probrK   rM   rA   rL   r   s
             r   __init__zMixup.__init__h   s}    &(*)t))*a/// #D&.&	&!r   c           	      P   t        j                  |t         j                        }t        j                  |t              }| j
                  r| j                  dkD  r| j                  dkD  rt         j                  j                  |      | j                  k  }t        j                  |t         j                  j                  | j                  | j                  |      t         j                  j                  | j                  | j                  |            }n| j                  dkD  r7t         j                  j                  | j                  | j                  |      }nh| j                  dkD  rRt        j                  |t              }t         j                  j                  | j                  | j                  |      }nJ d       t        j                  t         j                  j                  |      | j                  k  |j                  t         j                        |      }||fS )Ndtype        r   ROne of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true.)r!   onesfloat32zerosboolrN   rG   rH   r$   randrK   wherebetarJ   astype)rO   
batch_sizer   
use_cutmixlam_mixs        r   _params_per_elemzMixup._params_per_elemy   s   ggj

3XXj5
"$):):R)?YY^^J7$:J:JJ
((IINN4#4#4d6G6GjNYIINN4#3#3T5E5EJNWY !!B&))..)9)94;K;KR\.]""R'WWZt<
))..):):D<M<MT^._rrru((299>>*5Ew~~VXV`V`GacfgCJr   c                 $   d}d}| j                   r|t        j                  j                         | j                  k  rP| j
                  dkD  r| j                  dkD  rt        j                  j                         | j                  k  }|r4t        j                  j                  | j                  | j                        n3t        j                  j                  | j
                  | j
                        }n| j
                  dkD  r5t        j                  j                  | j
                  | j
                        }nM| j                  dkD  r7d}t        j                  j                  | j                  | j                        }nJ d       t        |      }||fS )Nr   FrU   TrV   )
rN   r!   r$   r[   rJ   rG   rH   rK   r]   r?   )rO   r   r`   ra   s       r   _params_per_batchzMixup._params_per_batch   s#   
")).."2T]]"B"$):):R)?YY^^-0@0@@
R\"))..):):D<M<MNIINN4#3#3T5E5EF !!B&))..)9)94;K;KL""R'!
))..):):D<M<MNrrru.CJr   c                    t        |      }| j                  |      \  }}|j                         }t        |      D ]  }||z
  dz
  }||   }|dk7  s||   r^t	        ||   j
                  || j                  | j                        \  \  }	}
}}}||   d d |	|
||f   ||   d d |	|
||f<   |||<   y||   |z  ||   d|z
  z  z   ||<    t        j                  ||j                  |j                        j                  d      S )Nr   r   r@   rA   r   rT   )r9   rb   clonerangerC   shaperI   rA   r	   tensorr   rT   	unsqueezerO   r   r_   	lam_batchr`   x_origijr   r3   r4   r5   r6   s                r   	_mix_elemzMixup._mix_elem   s#   V
 $ 5 5j A	:z" 
	>AQ"AA,Cbya=,?!

Cd6H6HVZVfVf-h)$RRc,21IaB2o,FAaDBrE2b5)#&IaLQ4#:q	QW(==AaD
	> ||IahhaggFPPQRSSr   c                    t        |      }| j                  |dz        \  }}|j                         }t        |dz        D ]  }||z
  dz
  }||   }|dk7  s||   r~t	        ||   j
                  || j                  | j                        \  \  }	}
}}}||   d d |	|
||f   ||   d d |	|
||f<   ||   d d |	|
||f   ||   d d |	|
||f<   |||<   ||   |z  ||   d|z
  z  z   ||<   ||   |z  ||   d|z
  z  z   ||<    t        j                  ||d d d   f      }t        j                  ||j                  |j                        j                  d      S )Nr    r   r   rf   r   rg   )r9   rb   rh   ri   rC   rj   rI   rA   r!   concatenater	   rk   r   rT   rl   rm   s                r   	_mix_pairzMixup._mix_pair   s   V
 $ 5 5jAo F	:zQ' 	>AQ"AA,Cbya=,?!

Cd6H6HVZVfVf-h)$RRc,21IaB2o,FAaDBrE2b5),21IaB2o,FAaDBrE2b5)#&IaLQ4#:q	QW(==AaDQ4#:q	QW(==AaD	> NNIy2#?@	||IahhaggFPPQRSSr   c                    | j                         \  }}|dk(  ry|rft        |j                  || j                  | j                        \  \  }}}}}|j                  d      d d d d ||||f   |d d d d ||||f<   |S |j                  d      j                  d|z
        }|j                  |      j                  |       |S )Nr   rf   r   )rd   rC   rj   rI   rA   r   mul_add_)	rO   r   r   r`   r3   r4   r5   r6   	x_flippeds	            r   
_mix_batchzMixup._mix_batch   s    002Z"9$74+=+=4K[K[%]!RRc$%FF1IaBrE2b5.@$AAaBrE2b5 ! 
 q	rCx0IFF3KY'
r   c                 $   t        |      dz  dk(  sJ d       | j                  dk(  r| j                  |      }n2| j                  dk(  r| j                  |      }n| j	                  |      }t        || j                  || j                        }||fS )Nr    r   )Batch size should be even when using thiselempair)r9   rM   rr   ru   rz   r   r   rL   )rO   r   r   r   s       r   __call__zMixup.__call__   s    1vzQK KK99..#CYY& ..#C//!$Cfd&6&6T=Q=QR&yr   )	r   rU   Nr   g      ?batchTg?i  )__name__
__module____qualname____doc__rQ   rb   rd   rr   ru   rz   r    r   r   rE   rE   Z   s6     cfRV""($T"T(	r   rE   c                   ,    e Zd ZdZddZd Zd ZddZy)	FastCollateMixupz Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch

    A Mixup impl that's performed while collating the batches.
    c           	         t        |      }|r|dz  n|}t        |      |k(  sJ | j                  |      \  }}t        |d   d   t        j                        }t        |      D ]  }	||	z
  dz
  }
||	   }||	   d   }|dk7  r3||	   r|s"|r|j                         n|j                         }t        |j                  || j                  | j                        \  \  }}}}}||
   d   d d ||||f   |d d ||||f<   |||	<   n|rf|j                  t        j                        |z  ||
   d   j                  t        j                        d|z
  z  z   }t        j                  ||       nG|j                         |z  ||
   d   j                         d|z
  z  z   }t!        j"                  ||       ||	xx   |r2t!        j$                  |j                  t        j&                              n|j)                         z  cc<    |r*t        j*                  |t        j,                  |      f      }t!        j.                  |      j1                  d      S )Nr    r   r   r   rf   out)r9   rb   
isinstancer!   ndarrayri   copyrh   rC   rj   rI   rA   r^   rX   rintr?   r	   round
from_numpyuint8bytert   rW   rk   rl   )rO   outputr   halfr_   num_elemrn   r`   is_nprp   rq   r   mixedr3   r4   r5   r6   s                    r   _mix_elem_collatez"FastCollateMixup._mix_elem_collate   s(   Z
&*:?
6{h&&& $ 5 5h ?	:58A;

3x 	]AQ"AA,C!HQKEbya=05

5;;=,?%)%7%7$($4$4	-)$RRc .31Xa[BrE2b5-IE!RUBrE/*#&IaL %RZZ 83 >q!ASASTVT^T^A_cdgjcjAk k51 % 3eAhqk6G6G6IQQTW6U UEu51IU))%,,rxx*@APUPZPZP\\I/	]0 	27783D'EFI||I&0033r   c           	         t        |      }| j                  |dz        \  }}t        |d   d   t        j                        }t        |dz        D ]  }||z
  dz
  }||   }	||   d   }
||   d   }d|	cxk  rdk  sJ  J |	dk  r||   rt        |j                  |	| j                  | j                        \  \  }}}}}	|r|
d d ||||f   j                         n|
d d ||||f   j                         }|d d ||||f   |
d d ||||f<   ||d d ||||f<   |	||<   nG|r|
j                  t        j                        |	z  |j                  t        j                        d|	z
  z  z   }|j                  t        j                        |	z  |
j                  t        j                        d|	z
  z  z   }|}
t        j                  ||       t        j                  |
|
       n|
j                         |	z  |j                         d|	z
  z  z   }|j                         |	z  |
j                         d|	z
  z  z   }|}
t!        j"                  ||       t!        j"                  |
|
       ||xx   |r2t!        j$                  |
j                  t        j&                              n|
j)                         z  cc<   ||xx   |r2t!        j$                  |j                  t        j&                              n|j)                         z  cc<    t        j*                  ||d d d   f      }t!        j,                  |      j/                  d      S )Nr    r   r   r   rf   r   r   )r9   rb   r   r!   r   ri   rC   rj   rI   rA   r   rh   r^   rX   r   r?   r	   r   r   r   r   rt   rk   rl   )rO   r   r   r_   rn   r`   r   rp   rq   r   mixed_imixed_jr3   r4   r5   r6   patch_i
mixed_temps                     r   _mix_pair_collatez"FastCollateMixup._mix_pair_collate  s   Z
 $ 5 5jAo F	:58A;

3zQ'  	aAQ"AA,CAhqkGAhqkG?s?"?"?Rxa=,?%)%7%7$($4$4	-)$RRc BGgaB2o6;;=GTUWYZ\W\^`ac^cTcLdLjLjLlG/6q"R%B/GGAr"ubeO,/6GAr"ubeO,#&IaL%,^^BJJ%?#%EWYWaWaHbfgjmfmHn%n
").."<s"BW^^TVT^T^E_cdgjcjEk"k",W5W5%,]]_s%:W]]_PQTWPW=X%X
")--/C"7'--/QQTW:U"U",G9G91Iu))'..*BCRYR^R^R``I1Iu))'..*BCRYR^R^R``IA 	aB NNIy2#?@	||I&0033r   c           	         t        |      }| j                         \  }}t        |d   d   t        j                        }|r5t        |j                  || j                  | j                        \  \  }}}	}
}t        |      D ][  }||z
  dz
  }||   d   }|dk7  r|rC|r|j                         n|j                         }||   d   d d 	
f   |d d |||	|
f<   n|rf|j                  t        j                        |z  ||   d   j                  t        j                        d|z
  z  z   }t        j                  ||       nG|j                         |z  ||   d   j                         d|z
  z  z   }t!        j"                  ||       ||xx   |r2t!        j$                  |j                  t        j&                              n|j)                         z  cc<   ^ |S )Nr   rf   r   r   r   )r9   rd   r   r!   r   rC   rj   rI   rA   ri   r   rh   r^   rX   r   r?   r	   r   r   r   r   )rO   r   r   r_   r   r`   r   r3   r4   r5   r6   rp   rq   r   s                 r   _mix_batch_collatez#FastCollateMixup._mix_batch_collate/  s   Z
002Z58A;

3$7!// ,,	%!RRc z" 	]AQ"A!HQKEby,1EJJLu{{}E-21Xa[BrE2b5-IE!RUBrE/* %RZZ 83 >q!ASASTVT^T^A_cdgjcjAk k51 % 3eAhqk6G6G6IQQTW6U UEu51IU))%,,rxx*@APUPZPZP\\I	] 
r   Nc                 l   t        |      }|dz  dk(  sJ d       d| j                  v }|r|dz  }t        j                  |g|d   d   j                  t        j
                        }| j                  dk(  s| j                  dk(  r| j                  |||      }n4| j                  dk(  r| j                  ||      }n| j                  ||      }t        j                  |D cg c]  }|d	   	 c}t        j                        }t        || j                  || j                        }|d | }||fS c c}w )
Nr    r   r|   r   rS   r}   )r   r~   r   )r9   rM   r	   rY   rj   r   r   r   r   rk   int64r   r   rL   )	rO   r   _r_   r   r   r   br   s	            r   r   zFastCollateMixup.__call__L  s    Z
A~"O$OO""1Jj=58A;+<+<=U[[Q99$))v"5((T(BCYY& ((7C))&%8CU3qt35;;Gfd&6&6T=Q=QR$v~ 4s   #D1)FN)r   r   r   r   r   r   r   r   r   r   r   r   r      s    
!4F'4R:r   r   )r   rU   )rU   Nr   )NTN)r   numpyr!   r	   r   r   r7   r=   rC   rE   r   r   r   r   <module>r      sG     g
&0.
!@ @Fu r   