
    khK              
          d Z ddlZddlZddlmZmZ ddlZddlmZ ddl	mc m
Z ddlmZ ddlmZ ddlmZ ddlmZ  eej*                  j-                  d	d            dkD  Z	 	 d(d
eeef   deeeef      dedej2                  fdZdeeef   deedf   fdZ	 	 d)dedefdZdeeef   deedf   fdZ G d dej>                        Z 	 	 d*deeef   deeef   fdZ! G d dej>                        Z"	 d+dedee   fd Z#d!ed"ed#ej2                  d$ej2                  dej2                  f
d%Z$ G d& d'ej>                        Z%y),zf Relative position embedding modules and functions

Hacked together by / Copyright 2022 Ross Wightman
    N)OptionalTuple   )ndgrid)RegularGridInterpolator)Mlp)trunc_normal_TIMM_USE_SCIPY_INTERPq_sizek_sizeclass_tokenreturnc           	         |J d       t        j                  t        t        j                  | d         t        j                  | d                     j	                  d      }|d d d d d f   |d d d d d f   z
  }|j                  ddd      }|d d d d dfxx   | d   dz
  z  cc<   |d d d d dfxx   | d   dz
  z  cc<   |d d d d dfxx   d| d   z  dz
  z  cc<   d| d   z  dz
  d| d   z  dz
  z  }|j                  d      }|r5t        j                  |g d      }||ddd f<   |dz   |dd df<   |dz   |d<   |j                         S )Nz-Different q & k sizes not currently supportedr   r      )r   r   r   r   r   r   )
torchstackr   arangeflattenpermutesumFpad
contiguous)r   r   r   coordsrelative_coordsnum_relative_distancerelative_position_indexs          U/var/www/teggl/fontify/venv/lib/python3.12/site-packages/timm/layers/pos_embed_rel.pygen_relative_position_indexr!      sw    >JJJ>[[VAY 7fQi9PQRZZ[\]FQ4Z(6!T1*+==O%--aA6OAq!Gq	A-Aq!Gq	A-Aq!GF1I 11]Q.1vay=13DE0 .11"5 #$%%(?"N)>12&)>)BA&(=(A%"--//    new_window_sizenew_bias_shape.c           	      l   |d   dz  dz
  |d   dz  dz
  f}| j                   dk(  r|\  }}}| j                  \  }}}	||d   k(  r||d   k(  sJ ||k7  s|	|k7  rKt        j                  j                  j                  | j                  d      |dd      j                  d      } | S | j                   dk(  sJ |\  }
}| j                  \  }}|
|d   |d   z  z
  }t        ||z
  dz        }||f}|d   |d   k7  s|d   |d   k7  r|r| | d d d f   }| d | d d f   } nd }t        j                  j                  j                  | j                  dd      j                  dd	|d   |d   f      |dd      j                  d	|
|z
        j                  dd      } |t        j                  | |fd
      } | S )Nr   r   r      bicubicF)sizemodealign_corners      ?r   dim)ndimshaper   nn
functionalinterpolate	unsqueezesqueezeint	transposereshapeviewcat)rel_pos_biasr#   r$   dst_size_dst_hdst_wnum_attn_headssrc_hsrc_wdst_num_possrc_num_posnum_extra_tokenssrc_sizeextra_tokenss                  r     resize_rel_pos_bias_table_simplerG   I   s#   
  "Q&*OA,>,BQ,FGHA(5%'3'9'9$u#!(<<<E>Ue^ 88..::&&q)#	 ; 
 gaj @ 3   A%%%'Q&2&8&8#^&(1+*CD&663>?h'A;(1+%!)C+-=,=,>,AB+,>.>->,>,AB# 88..::&&q!,44aXa[(ST+5VW#	 ; 
 d2{%556yyA  '$yy,)E1Mr"   interpolation	antialiasc                    | j                         \  }}|\  }}||k(  sJ ||k7  r| j                  }| j                         } t        |dz        }	t        |dz        }
t	        j
                  | j                  dd      j                  d||	|	      |
|
f||      }|j                  ||      j                  dd      }|j                  |       |S | S )z
    Resample relative position bias table suggested in LeVit
    Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py
    r+   r   r   )r(   r)   rI   )	r(   dtypefloatr5   r   r2   r   r8   to)position_bias_tablenew_sizerH   rI   L1nH1L2nH2
orig_dtypeS1S2$relative_position_bias_table_resizeds               r    resize_rel_pos_bias_table_levitrX   x   s     "&&(GBGB#::	Rx(..
1779s^s^/0}}''1-221c2rBb	0!, 155c2>FFq!L 	-,//
;33""r"   c                    t         rddlm} |d   dz  dz
  |d   dz  dz
  f}| j                  dk(  r1d}|\  }}}||d   k(  r||d   k(  sJ | j                  \  }	}
}|
|f}d}nJ| j                  dk(  sJ |\  }}| j                  \  }}	||d   |d   z  z
  }t        ||z
  dz        }||f}d}|d   |d   k7  s|d   |d   k7  r:|r| | d	d	d	f   }| d	| d	d	f   } nd	}d
 fd} ||d   |d         } ||d   |d         }t        j                  |      t        j                  |      g}|d   dz  }|d   dz  }t        j                  | |dz   d      }t        j                  | |dz   d      }t        ||      }g }t        |	      D ]  }|r.| d	d	|f   j                  |d   |d         j                         }n| |d	d	d	d	f   j                         }t         rgj                  |||j                         d      }t        j                   |||            j!                         j#                  | j$                        }n;t'        ||      } ||      j!                         j#                  | j$                        }|r|j                  dd      }|j)                  |        |rt        j*                  |d      } nt        j*                  |d      } ||sJ t        j*                  | |fd      } | S )a   Resize relative position bias table using more advanced interpolation.

    Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc).

    https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351

    Args:
        rel_pos_bias:
        new_window_size:
        new_bias_shape:

    Returns:

    r   )r2   r   r   r&   Fr+   TNc                 $    | d||z  z
  z  d|z
  z  S )N      ? )arns      r    geometric_progressionz8resize_rel_pos_bias_table.<locals>.geometric_progression   s    a1f%q11r"   c                 (   d\  }}||z
  dkD  r+||z   dz  } d|| dz        }||dz  kD  r|}n|}||z
  dkD  r+g }d}t        | dz        D ]  }|j                  |       ||dz   z  z  }  t        |      D 	cg c]  }	|	  }
}	|
dgz   |z   S c c}	w )N)g)\(?g      ?gư>       @r   r   r   )rangeappendreversed)srcdstleftrightqgpdiscurir<   r_idsr`   s              r    _calcz(resize_rel_pos_bias_table.<locals>._calc   s    #KD%$,%E\S(*1a:q=ED $,% CC3!8_ $

3qQU|#$ "*#/AaR/E/A3;$$ 0s   :
Brb   g?r[   cubic)kindr   r,   )
_USE_SCIPYscipyr2   r.   r/   r5   r   tensorr   r   rc   r8   rL   interp2dnumpyTensorr   rM   devicer   rd   r9   ) r:   r#   r$   r2   r;   rD   r<   r=   r>   r?   r@   rA   rE   has_flat_shaperB   rC   rF   rp   yxyxtytxdydxdyxall_rel_pos_biasrn   zfr^   r`   s                                   @r    resize_rel_pos_bias_tabler      sJ   & %"Q&*OA,>,BQ,FGHA(5%#!(<<<'3'9'9$u5>  A%%%'Q&2&8&8#^&(1+*CD&663>?h'{hqk!Xa[HQK%?')9(9(:A(=>L'(:*:):(:A(=>LL	2	%$ (1+x{+(1+x{+ll1ou||A/ a[Ca[C\\2#rCx-\\2#rCx-Rn ~& 	'A A&++HQK!EKKM Aq)//1((Aqwwyw(GLL2r+668;;L<O<OP
 ,B2cF%%'**<+>+>?FF2qM##A&'	'*  99%52>L 99%51=L#!!> 99lL%AqILr"   c                   r     e Zd ZdZd fd	Zd Zdej                  fdZd	de	ej                     fdZ
 xZS )

RelPosBiasz_ Relative Position Bias
    Adapted from Swin-V1 relative position bias impl, modularized.
    c                    t         |           |dk  sJ || _        |d   |d   z  | _        | j                  |z   fdz  |fz   | _        d|d   z  dz
  d|d   z  dz
  z  d|z  z   }t        j                  t        j                  ||            | _	        | j                  dt        | j                  |dkD        j                  d      d	       | j                          y )
Nr   r   r   r&   r   )r   r   F
persistent)super__init__window_sizewindow_area
bias_shaper0   	Parameterr   zerosrelative_position_bias_tableregister_bufferr!   r8   init_weights)selfr   	num_headsprefix_tokensr   	__class__s        r    r   zRelPosBias.__init__  s    !!!&&q>KN:++m;=AYLP!"[^!3a!7AA<NQR<R SVWZgVg g,.LLEZ\e9f,g)%'(8(8mVWFWX]]^`a 	 	
 	r"   c                 2    t        | j                  d       y Ng{Gz?)std)r	   r   r   s    r    r   zRelPosBias.init_weights"  s    d77SAr"   r   c                     | j                   | j                     }|j                  | j                        j	                  ddd      }|j                  d      j                         S )Nr   r   r   )r   r   r8   r   r   r3   r   r   relative_position_biass     r    get_biaszRelPosBias.get_bias%  sZ    !%!B!B4C_C_!`!7!<!<T__!M!U!UVWYZ\]!^%//2==??r"   shared_rel_posc                 (    || j                         z   S Nr   r   attnr   s      r    forwardzRelPosBias.forward+      dmmo%%r"   r   r   __name__
__module____qualname____doc__r   r   r   rx   r   r   r   __classcell__r   s   @r    r   r     s:    "B@%,, @&HU\\,B &r"   r   win_sizepretrained_win_sizec                    |dv sJ t        j                  | d   dz
   | d         j                  t         j                        }t        j                  | d   dz
   | d         j                  t         j                        }t        j                  t        ||            }|j                  ddd      j                         }|dk(  r|d   dkD  r5|d d d d dfxx   |d   dz
  z  cc<   |d d d d dfxx   |d   dz
  z  cc<   n4|d d d d dfxx   | d   dz
  z  cc<   |d d d d dfxx   | d   dz
  z  cc<   |dz  }t        j                  |      t        j                  d|j                         z         z  t        j                  d      z  }|S t        j                  |      t        j                  d|j                         z         z  }|S )N)swincrr   r   r   r      r[   )r   r   rM   float32r   r   r   r   signlog2absmathlog)r   r   r)   relative_coords_hrelative_coords_wrelative_coords_tables         r    gen_relative_log_coordsr   /  s   
 >!!!x{Q%7!EHHWx{Q%7!EHHW!KK/@BS(TU199!QBMMOv~q!A%!!Q'*/B1/E/IJ*!!Q'*/B1/E/IJ*!!Q'*x{Q?*!!Q'*x{Q?*" %

+@ AEJJ'++--E/ !/151!> !  !&

+@ AEII'++--E/ !/ ! r"   c                   v     e Zd ZdZ	 	 	 	 	 d fd	Zdej                  fdZddeej                     fdZ	 xZ
S )		RelPosMlpz Log-Coordinate Relative Position MLP
    Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883)

    This impl covers the 'swin' implementation as well as two timm specific modes ('cr', and 'rw')
    c                 :   t         |           || _        | j                  d   | j                  d   z  | _        || _        || _        | j                  fdz  |fz   | _        |dk(  r#t        j                         | _	        d| _
        d}n"t        j                         | _	        d | _
        d}t        d||t        j                  |d	      | _        | j                  d
t!        |      j#                  d      d       | j                  dt%        |||      d       y )Nr   r   r   r      )TFT)g      ?g        )hidden_featuresout_features	act_layerbiasdropr   r   Fr   rel_coords_log)r)   )r   r   r   r   r   r   r   r0   Sigmoidbias_act	bias_gainIdentityr   ReLUmlpr   r!   r8   r   )	r   r   r   
hidden_dimr   r)   pretrained_window_sizemlp_biasr   s	           r    r   zRelPosMlp.__init__R  s    	&++A.1A1A!1DD*"++-1YL@6>JJLDMDN$HKKMDM!DNH&"gg
 	%'499"= 	 	 	#K1GdS 	 	r"   r   c                    | j                  | j                        }| j                  D|j                  d| j                        | j                     }|j                  | j
                        }|j                  ddd      }| j                  |      }| j                  | j                  |z  }| j                  r.t        j                  || j                  d| j                  dg      }|j                  d      j                         S )Nr   r   r   r   )r   r   r   r8   r   r   r   r   r   r   r   r   r3   r   r   s     r    r   zRelPosMlp.get_bias~  s    !%$*=*=!>''3%;%@%@T^^%TUYUqUq%r"%;%@%@%Q"!7!?!?1a!H!%/E!F>>%%)^^6L%L"%&UU+ADDVDVXY[_[m[mopCq%r"%//2==??r"   r   c                 (    || j                         z   S r   r   r   s      r    r   zRelPosMlp.forward  r   r"   )r      r   r   r   r   )r   r   r   r   r   r   rx   r   r   r   r   r   s   @r    r   r   L  sF     #)*X@%,, @&HU\\,B &r"   r   lengthmax_relative_positionc                     || dz
  }d|z  dz   }t        j                  | | |      }t        |       D ]4  }t        |       D ]$  }||z
  |z   }t        ||z
        |kD  rd||||f<   & 6 |S )a  Generate a one_hot lookup tensor to reindex embeddings along one dimension.

    Args:
        length: the length to reindex to.
        max_relative_position: the maximum relative position to consider.
            Relative position embeddings for distances above this threshold
            are zeroed out.
    Returns:
        a lookup Tensor of size [length, length, vocab_size] that satisfies
            ret[n,m,v] = 1{m - n + max_relative_position = v}.
    r   r   )r   r   rc   r   )r   r   
vocab_sizeretrn   r|   vs          r    generate_lookup_tensorr     s     $ &
**Q.J
++ffj
1C6] v 	AA--A1q5z11C1aL		 Jr"   heightwidthheight_lookupwidth_lookupc                     t        j                  d| |      }t        j                  d||      }||z  }|j                  | j                  d   ||      S )a\  Reindex 2d relative position bias with 2 independent einsum lookups.

    Adapted from:
     https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py

    Args:
        relative_position_tensor: tensor of shape
            [..., vocab_height, vocab_width, ...].
        height: height to reindex to.
        width: width to reindex to.
        height_lookup: one-hot height lookup
        width_lookup: one-hot width lookup
    Returns:
        reindexed_tensor: a Tensor of shape
            [..., height * width, height * width, ...]
    znhw,ixh->nixwznixw,jyw->nijxyr   )r   einsumr7   r/   )relative_position_tensorr   r   r   r   reindexed_tensorareas          r    reindex_2d_einsum_lookupr     sW    . ||O5M}]||$57GVE>D##$<$B$B1$EtTRRr"   c                   r     e Zd ZdZd fd	Zd Zdej                  fdZd	de	ej                     fdZ
 xZS )
RelPosBiasTfz Relative Position Bias Impl (Compatible with Tensorflow MaxViT models)
    Adapted from:
     https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
    c                    t         |           |dk  sJ || _        |d   |d   z  | _        || _        d|d   z  dz
  }d|d   z  dz
  }| j                  ||f| _        t        j                  t        j                  | j
                              | _
        | j                  dt        |d         d       | j                  dt        |d         d       | j                          y )Nr   r   r   r   Fr   r   )r   r   r   r   r   r   r0   r   r   r   r   r   r   r   )r   r   r   r   vocab_heightvocab_widthr   s         r    r   zRelPosBiasTf.__init__  s    !!!&&q>KN:";q>)A-+a.(1,>><E,.LLT__9U,V)_.D[QR^.Tafg^-CKPQN-S`efr"   c                 Z    t         j                  j                  | j                  d       y r   )r0   initnormal_r   r   s    r    r   zRelPosBiasTf.init_weights  s    
99sCr"   r   c                     t        | j                  | j                  d   | j                  d   | j                  | j                        S )Nr   r   )r   r   r   r   r   r   s    r    r   zRelPosBiasTf.get_bias  sF    '--QQ
 	
r"   r   c                 (    || j                         z   S r   r   r   s      r    r   zRelPosBiasTf.forward  r   r"   r   r   r   r   s   @r    r   r     s8    D
%,, 
&HU\\,B &r"   r   )NF)r'   T)r   r   r   )&r   r   ostypingr   r   r   torch.nnr0   torch.nn.functionalr1   r   gridr   r2   r   r   r   weight_initr	   r5   environgetrs   boolrx   r!   rG   strrX   r   Moduler   r   r   r   r   r   r\   r"   r    <module>r      s    	 "      0  & 7;<q@

 -1!10c3h10sCx)10 10 \\	10h,sCx, c3h,d '	# # 	#@qsCxq c3hqh &  &J 06!S/!"38_!:@&		 @&J 04'}:SS S ||	S
 llS \\S:"&299 "&r"   